summary refs log tree commit diff
path: root/tests/server.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/server.py')
-rw-r--r--tests/server.py57
1 files changed, 49 insertions, 8 deletions
diff --git a/tests/server.py b/tests/server.py
index c611dd6059..c63b2c3100 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -11,6 +11,7 @@ from twisted.python.failure import Failure
 from twisted.test.proto_helpers import MemoryReactorClock
 
 from synapse.http.site import SynapseRequest
+from synapse.util import Clock
 
 from tests.utils import setup_test_homeserver as _sth
 
@@ -23,12 +24,19 @@ class FakeChannel(object):
     """
 
     result = attr.ib(default=attr.Factory(dict))
+    _producer = None
 
     @property
     def json_body(self):
         if not self.result:
             raise Exception("No result yet.")
-        return json.loads(self.result["body"])
+        return json.loads(self.result["body"].decode('utf8'))
+
+    @property
+    def code(self):
+        if not self.result:
+            raise Exception("No result yet.")
+        return int(self.result["code"])
 
     def writeHeaders(self, version, code, reason, headers):
         self.result["version"] = version
@@ -42,6 +50,15 @@ class FakeChannel(object):
 
         self.result["body"] += content
 
+    def registerProducer(self, producer, streaming):
+        self._producer = producer
+
+    def unregisterProducer(self):
+        if self._producer is None:
+            return
+
+        self._producer = None
+
     def requestDone(self, _self):
         self.result["done"] = True
 
@@ -79,11 +96,16 @@ def make_request(method, path, content=b""):
     Make a web request using the given method and path, feed it the
     content, and return the Request and the Channel underneath.
     """
+    if not isinstance(method, bytes):
+        method = method.encode('ascii')
+
+    if not isinstance(path, bytes):
+        path = path.encode('ascii')
 
     # Decorate it to be the full path
     if not path.startswith(b"/_matrix"):
         path = b"/_matrix/client/r0/" + path
-        path = path.replace("//", "/")
+        path = path.replace(b"//", b"/")
 
     if isinstance(content, text_type):
         content = content.encode('utf8')
@@ -99,14 +121,19 @@ def make_request(method, path, content=b""):
     return req, channel
 
 
-def wait_until_result(clock, channel, timeout=100):
+def wait_until_result(clock, request, timeout=100):
     """
-    Wait until the channel has a result.
+    Wait until the request is finished.
     """
     clock.run()
     x = 0
 
-    while not channel.result:
+    while not request.finished:
+
+        # If there's a producer, tell it to resume producing so we get content
+        if request._channel._producer:
+            request._channel._producer.resumeProducing()
+
         x += 1
 
         if x > timeout:
@@ -117,13 +144,14 @@ def wait_until_result(clock, channel, timeout=100):
 
 def render(request, resource, clock):
     request.render(resource)
-    wait_until_result(clock, request._channel)
+    wait_until_result(clock, request)
 
 
 class ThreadedMemoryReactorClock(MemoryReactorClock):
     """
     A MemoryReactorClock that supports callFromThread.
     """
+
     def callFromThread(self, callback, *args, **kwargs):
         """
         Make the callback fire in the next reactor iteration.
@@ -134,12 +162,15 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
         return d
 
 
-def setup_test_homeserver(*args, **kwargs):
+def setup_test_homeserver(cleanup_func, *args, **kwargs):
     """
     Set up a synchronous test server, driven by the reactor used by
     the homeserver.
     """
-    d = _sth(*args, **kwargs).result
+    d = _sth(cleanup_func, *args, **kwargs).result
+
+    if isinstance(d, Failure):
+        d.raiseException()
 
     # Make the thread pool synchronous.
     clock = d.get_clock()
@@ -172,9 +203,13 @@ def setup_test_homeserver(*args, **kwargs):
         """
         Threadless thread pool.
         """
+
         def start(self):
             pass
 
+        def stop(self):
+            pass
+
         def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
             def _(res):
                 if isinstance(res, Failure):
@@ -191,3 +226,9 @@ def setup_test_homeserver(*args, **kwargs):
     clock.threadpool = ThreadPool()
     pool.threadpool = ThreadPool()
     return d
+
+
+def get_clock():
+    clock = ThreadedMemoryReactorClock()
+    hs_clock = Clock(clock)
+    return (clock, hs_clock)