diff options
Diffstat (limited to 'tests/server.py')
-rw-r--r-- | tests/server.py | 93 |
1 files changed, 61 insertions, 32 deletions
diff --git a/tests/server.py b/tests/server.py index e397ebe8fa..1644710aa0 100644 --- a/tests/server.py +++ b/tests/server.py @@ -20,6 +20,7 @@ from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock from twisted.web.http import unquote from twisted.web.http_headers import Headers +from twisted.web.server import Site from synapse.http.site import SynapseRequest from synapse.util import Clock @@ -42,6 +43,7 @@ class FakeChannel(object): wire). """ + site = attr.ib(type=Site) _reactor = attr.ib() result = attr.ib(default=attr.Factory(dict)) _producer = None @@ -161,7 +163,11 @@ def make_request( path = path.encode("ascii") # Decorate it to be the full path, if we're using shorthand - if shorthand and not path.startswith(b"/_matrix"): + if ( + shorthand + and not path.startswith(b"/_matrix") + and not path.startswith(b"/_synapse") + ): path = b"/_matrix/client/r0/" + path path = path.replace(b"//", b"/") @@ -172,9 +178,9 @@ def make_request( content = content.encode("utf8") site = FakeSite() - channel = FakeChannel(reactor) + channel = FakeChannel(site, reactor) - req = request(site, channel) + req = request(channel) req.process = lambda: b"" req.content = BytesIO(content) req.postpath = list(map(unquote, path[1:].split(b"/"))) @@ -298,41 +304,42 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): Set up a synchronous test server, driven by the reactor used by the homeserver. """ - d = _sth(cleanup_func, *args, **kwargs).result + server = _sth(cleanup_func, *args, **kwargs) - if isinstance(d, Failure): - d.raiseException() + database = server.config.database.get_single_database() # Make the thread pool synchronous. - clock = d.get_clock() - pool = d.get_db_pool() - - def runWithConnection(func, *args, **kwargs): - return threads.deferToThreadPool( - pool._reactor, - pool.threadpool, - pool._runWithConnection, - func, - *args, - **kwargs - ) - - def runInteraction(interaction, *args, **kwargs): - return threads.deferToThreadPool( - pool._reactor, - pool.threadpool, - pool._runInteraction, - interaction, - *args, - **kwargs - ) + clock = server.get_clock() + + for database in server.get_datastores().databases: + pool = database._db_pool + + def runWithConnection(func, *args, **kwargs): + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runWithConnection, + func, + *args, + **kwargs + ) + + def runInteraction(interaction, *args, **kwargs): + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runInteraction, + interaction, + *args, + **kwargs + ) - if pool: pool.runWithConnection = runWithConnection pool.runInteraction = runInteraction pool.threadpool = ThreadPool(clock._reactor) pool.running = True - return d + + return server def get_clock(): @@ -375,6 +382,7 @@ class FakeTransport(object): disconnecting = False disconnected = False + connected = True buffer = attr.ib(default=b"") producer = attr.ib(default=None) autoflush = attr.ib(default=True) @@ -391,11 +399,25 @@ class FakeTransport(object): self.disconnecting = True if self._protocol: self._protocol.connectionLost(reason) - self.disconnected = True + + # if we still have data to write, delay until that is done + if self.buffer: + logger.info( + "FakeTransport: Delaying disconnect until buffer is flushed" + ) + else: + self.connected = False + self.disconnected = True def abortConnection(self): logger.info("FakeTransport: abortConnection()") - self.loseConnection() + + if not self.disconnecting: + self.disconnecting = True + if self._protocol: + self._protocol.connectionLost(None) + + self.disconnected = True def pauseProducing(self): if not self.producer: @@ -426,6 +448,9 @@ class FakeTransport(object): self._reactor.callLater(0.0, _produce) def write(self, byt): + if self.disconnecting: + raise Exception("Writing to disconnecting FakeTransport") + self.buffer = self.buffer + byt # always actually do the write asynchronously. Some protocols (notably the @@ -470,6 +495,10 @@ class FakeTransport(object): if self.buffer and self.autoflush: self._reactor.callLater(0.0, self.flush) + if not self.buffer and self.disconnecting: + logger.info("FakeTransport: Buffer now empty, completing disconnect") + self.disconnected = True + def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol: """ |