summary refs log tree commit diff
path: root/tests/server.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/server.py')
-rw-r--r--tests/server.py68
1 files changed, 38 insertions, 30 deletions
diff --git a/tests/server.py b/tests/server.py
index 8f89f4a83d..c15a47f2a4 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -182,7 +182,8 @@ def make_request(
 
     if federation_auth_origin is not None:
         req.requestHeaders.addRawHeader(
-            b"Authorization", b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,)
+            b"Authorization",
+            b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,),
         )
 
     if content:
@@ -226,6 +227,8 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
     """
 
     def __init__(self):
+        self.threadpool = ThreadPool(self)
+
         self._udp = []
         lookups = self.lookups = {}
 
@@ -233,7 +236,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
         class FakeResolver(object):
             def getHostByName(self, name, timeout=None):
                 if name not in lookups:
-                    return fail(DNSLookupError("OH NO: unknown %s" % (name, )))
+                    return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
                 return succeed(lookups[name])
 
         self.nameResolver = SimpleResolverComplexifier(FakeResolver())
@@ -254,6 +257,37 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
         self.callLater(0, d.callback, True)
         return d
 
+    def getThreadPool(self):
+        return self.threadpool
+
+
+class ThreadPool:
+    """
+    Threadless thread pool.
+    """
+
+    def __init__(self, reactor):
+        self._reactor = reactor
+
+    def start(self):
+        pass
+
+    def stop(self):
+        pass
+
+    def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
+        def _(res):
+            if isinstance(res, Failure):
+                onResult(False, res)
+            else:
+                onResult(True, res)
+
+        d = Deferred()
+        d.addCallback(lambda x: function(*args, **kwargs))
+        d.addBoth(_)
+        self._reactor.callLater(0, d.callback, True)
+        return d
+
 
 def setup_test_homeserver(cleanup_func, *args, **kwargs):
     """
@@ -289,36 +323,10 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
             **kwargs
         )
 
-    class ThreadPool:
-        """
-        Threadless thread pool.
-        """
-
-        def start(self):
-            pass
-
-        def stop(self):
-            pass
-
-        def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
-            def _(res):
-                if isinstance(res, Failure):
-                    onResult(False, res)
-                else:
-                    onResult(True, res)
-
-            d = Deferred()
-            d.addCallback(lambda x: function(*args, **kwargs))
-            d.addBoth(_)
-            clock._reactor.callLater(0, d.callback, True)
-            return d
-
-    clock.threadpool = ThreadPool()
-
     if pool:
         pool.runWithConnection = runWithConnection
         pool.runInteraction = runInteraction
-        pool.threadpool = ThreadPool()
+        pool.threadpool = ThreadPool(clock._reactor)
         pool.running = True
     return d
 
@@ -454,6 +462,6 @@ class FakeTransport(object):
             logger.warning("Exception writing to protocol: %s", e)
             return
 
-        self.buffer = self.buffer[len(to_write):]
+        self.buffer = self.buffer[len(to_write) :]
         if self.buffer and self.autoflush:
             self._reactor.callLater(0.0, self.flush)