diff --git a/tests/server.py b/tests/server.py
index a5e57c52fa..b6e0b14e78 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -237,6 +237,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def __init__(self):
self.threadpool = ThreadPool(self)
+ self._tcp_callbacks = {}
self._udp = []
lookups = self.lookups = {}
@@ -268,6 +269,29 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def getThreadPool(self):
return self.threadpool
+ def add_tcp_client_callback(self, host, port, callback):
+ """Add a callback that will be invoked when we receive a connection
+ attempt to the given IP/port using `connectTCP`.
+
+ Note that the callback gets run before we return the connection to the
+ client, which means callbacks cannot block while waiting for writes.
+ """
+ self._tcp_callbacks[(host, port)] = callback
+
+ def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
+ """Fake L{IReactorTCP.connectTCP}.
+ """
+
+ conn = super().connectTCP(
+ host, port, factory, timeout=timeout, bindAddress=None
+ )
+
+ callback = self._tcp_callbacks.get((host, port))
+ if callback:
+ callback()
+
+ return conn
+
class ThreadPool:
"""
@@ -486,7 +510,7 @@ class FakeTransport(object):
try:
self.other.dataReceived(to_write)
except Exception as e:
- logger.warning("Exception writing to protocol: %s", e)
+ logger.exception("Exception writing to protocol: %s", e)
return
self.buffer = self.buffer[len(to_write) :]
|