diff options
Diffstat (limited to 'tests/server.py')
-rw-r--r-- | tests/server.py | 241 |
1 files changed, 241 insertions, 0 deletions
diff --git a/tests/server.py b/tests/server.py new file mode 100644 index 0000000000..7dbdb7f8ea --- /dev/null +++ b/tests/server.py @@ -0,0 +1,241 @@ +import json +from io import BytesIO + +from six import text_type + +import attr + +from twisted.internet import address, threads +from twisted.internet.defer import Deferred +from twisted.python.failure import Failure +from twisted.test.proto_helpers import MemoryReactorClock + +from synapse.http.site import SynapseRequest +from synapse.util import Clock + +from tests.utils import setup_test_homeserver as _sth + + +@attr.s +class FakeChannel(object): + """ + A fake Twisted Web Channel (the part that interfaces with the + wire). + """ + + result = attr.ib(default=attr.Factory(dict)) + _producer = None + + @property + def json_body(self): + if not self.result: + raise Exception("No result yet.") + return json.loads(self.result["body"].decode('utf8')) + + @property + def code(self): + if not self.result: + raise Exception("No result yet.") + return int(self.result["code"]) + + def writeHeaders(self, version, code, reason, headers): + self.result["version"] = version + self.result["code"] = code + self.result["reason"] = reason + self.result["headers"] = headers + + def write(self, content): + if "body" not in self.result: + self.result["body"] = b"" + + self.result["body"] += content + + def registerProducer(self, producer, streaming): + self._producer = producer + + def unregisterProducer(self): + if self._producer is None: + return + + self._producer = None + + def requestDone(self, _self): + self.result["done"] = True + + def getPeer(self): + # We give an address so that getClientIP returns a non null entry, + # causing us to record the MAU + return address.IPv4Address(b"TCP", "127.0.0.1", 3423) + + def getHost(self): + return None + + @property + def transport(self): + return self + + +class FakeSite: + """ + A fake Twisted Web Site, with mocks of the extra things that + Synapse adds. + """ + + server_version_string = b"1" + site_tag = "test" + + @property + def access_logger(self): + class FakeLogger: + def info(self, *args, **kwargs): + pass + + return FakeLogger() + + +def make_request(method, path, content=b"", access_token=None): + """ + Make a web request using the given method and path, feed it the + content, and return the Request and the Channel underneath. + """ + if not isinstance(method, bytes): + method = method.encode('ascii') + + if not isinstance(path, bytes): + path = path.encode('ascii') + + # Decorate it to be the full path + if not path.startswith(b"/_matrix"): + path = b"/_matrix/client/r0/" + path + path = path.replace(b"//", b"/") + + if isinstance(content, text_type): + content = content.encode('utf8') + + site = FakeSite() + channel = FakeChannel() + + req = SynapseRequest(site, channel) + req.process = lambda: b"" + req.content = BytesIO(content) + + if access_token: + req.requestHeaders.addRawHeader(b"Authorization", b"Bearer " + access_token) + + req.requestHeaders.addRawHeader(b"X-Forwarded-For", b"127.0.0.1") + req.requestReceived(method, path, b"1.1") + + return req, channel + + +def wait_until_result(clock, request, timeout=100): + """ + Wait until the request is finished. + """ + clock.run() + x = 0 + + while not request.finished: + + # If there's a producer, tell it to resume producing so we get content + if request._channel._producer: + request._channel._producer.resumeProducing() + + x += 1 + + if x > timeout: + raise Exception("Timed out waiting for request to finish.") + + clock.advance(0.1) + + +def render(request, resource, clock): + request.render(resource) + wait_until_result(clock, request) + + +class ThreadedMemoryReactorClock(MemoryReactorClock): + """ + A MemoryReactorClock that supports callFromThread. + """ + + def callFromThread(self, callback, *args, **kwargs): + """ + Make the callback fire in the next reactor iteration. + """ + d = Deferred() + d.addCallback(lambda x: callback(*args, **kwargs)) + self.callLater(0, d.callback, True) + return d + + +def setup_test_homeserver(cleanup_func, *args, **kwargs): + """ + Set up a synchronous test server, driven by the reactor used by + the homeserver. + """ + d = _sth(cleanup_func, *args, **kwargs).result + + if isinstance(d, Failure): + d.raiseException() + + # Make the thread pool synchronous. + clock = d.get_clock() + pool = d.get_db_pool() + + def runWithConnection(func, *args, **kwargs): + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runWithConnection, + func, + *args, + **kwargs + ) + + def runInteraction(interaction, *args, **kwargs): + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runInteraction, + interaction, + *args, + **kwargs + ) + + pool.runWithConnection = runWithConnection + pool.runInteraction = runInteraction + + class ThreadPool: + """ + Threadless thread pool. + """ + + def start(self): + pass + + def stop(self): + pass + + def callInThreadWithCallback(self, onResult, function, *args, **kwargs): + def _(res): + if isinstance(res, Failure): + onResult(False, res) + else: + onResult(True, res) + + d = Deferred() + d.addCallback(lambda x: function(*args, **kwargs)) + d.addBoth(_) + clock._reactor.callLater(0, d.callback, True) + return d + + clock.threadpool = ThreadPool() + pool.threadpool = ThreadPool() + return d + + +def get_clock(): + clock = ThreadedMemoryReactorClock() + hs_clock = Clock(clock) + return (clock, hs_clock) |