summary refs log tree commit diff
path: root/tests/server.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/server.py')
-rw-r--r--tests/server.py89
1 files changed, 86 insertions, 3 deletions
diff --git a/tests/server.py b/tests/server.py
index 420ec4e088..7bee58dff1 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -98,7 +98,7 @@ class FakeSite:
         return FakeLogger()
 
 
-def make_request(method, path, content=b"", access_token=None):
+def make_request(method, path, content=b"", access_token=None, request=SynapseRequest):
     """
     Make a web request using the given method and path, feed it the
     content, and return the Request and the Channel underneath.
@@ -120,14 +120,16 @@ def make_request(method, path, content=b"", access_token=None):
     site = FakeSite()
     channel = FakeChannel()
 
-    req = SynapseRequest(site, channel)
+    req = request(site, channel)
     req.process = lambda: b""
     req.content = BytesIO(content)
 
     if access_token:
         req.requestHeaders.addRawHeader(b"Authorization", b"Bearer " + access_token)
 
-    req.requestHeaders.addRawHeader(b"X-Forwarded-For", b"127.0.0.1")
+    if content:
+        req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
+
     req.requestReceived(method, path, b"1.1")
 
     return req, channel
@@ -280,3 +282,84 @@ def get_clock():
     clock = ThreadedMemoryReactorClock()
     hs_clock = Clock(clock)
     return (clock, hs_clock)
+
+
+@attr.s
+class FakeTransport(object):
+    """
+    A twisted.internet.interfaces.ITransport implementation which sends all its data
+    straight into an IProtocol object: it exists to connect two IProtocols together.
+
+    To use it, instantiate it with the receiving IProtocol, and then pass it to the
+    sending IProtocol's makeConnection method:
+
+        server = HTTPChannel()
+        client.makeConnection(FakeTransport(server, self.reactor))
+
+    If you want bidirectional communication, you'll need two instances.
+    """
+
+    other = attr.ib()
+    """The Protocol object which will receive any data written to this transport.
+
+    :type: twisted.internet.interfaces.IProtocol
+    """
+
+    _reactor = attr.ib()
+    """Test reactor
+
+    :type: twisted.internet.interfaces.IReactorTime
+    """
+
+    disconnecting = False
+    buffer = attr.ib(default=b'')
+    producer = attr.ib(default=None)
+
+    def getPeer(self):
+        return None
+
+    def getHost(self):
+        return None
+
+    def loseConnection(self):
+        self.disconnecting = True
+
+    def abortConnection(self):
+        self.disconnecting = True
+
+    def pauseProducing(self):
+        self.producer.pauseProducing()
+
+    def unregisterProducer(self):
+        if not self.producer:
+            return
+
+        self.producer = None
+
+    def registerProducer(self, producer, streaming):
+        self.producer = producer
+        self.producerStreaming = streaming
+
+        def _produce():
+            d = self.producer.resumeProducing()
+            d.addCallback(lambda x: self._reactor.callLater(0.1, _produce))
+
+        if not streaming:
+            self._reactor.callLater(0.0, _produce)
+
+    def write(self, byt):
+        self.buffer = self.buffer + byt
+
+        def _write():
+            if getattr(self.other, "transport") is not None:
+                self.other.dataReceived(self.buffer)
+                self.buffer = b""
+                return
+
+            self._reactor.callLater(0.0, _write)
+
+        _write()
+
+    def writeSequence(self, seq):
+        for x in seq:
+            self.write(x)