summary refs log tree commit diff
path: root/tests/server.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-07-15 15:27:35 +0100
committerGitHub <noreply@github.com>2020-07-15 15:27:35 +0100
commitf13061d5153eca9bd7054d5b89ade41f3a430f3b (patch)
tree827797d67d99b93f41e35fac8ff21cb886c01a66 /tests/server.py
parentConvert E2E key and room key handlers to async/await. (#7851) (diff)
downloadsynapse-f13061d5153eca9bd7054d5b89ade41f3a430f3b.tar.xz
Fix client reader sharding tests (#7853)
* Fix client reader sharding tests

* Newsfile

* Fix typing

* Update changelog.d/7853.misc

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>

* Move mocking of http_client to tests

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
Diffstat (limited to 'tests/server.py')
-rw-r--r--tests/server.py26
1 files changed, 25 insertions, 1 deletions
diff --git a/tests/server.py b/tests/server.py
index a5e57c52fa..b6e0b14e78 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -237,6 +237,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
     def __init__(self):
         self.threadpool = ThreadPool(self)
 
+        self._tcp_callbacks = {}
         self._udp = []
         lookups = self.lookups = {}
 
@@ -268,6 +269,29 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
     def getThreadPool(self):
         return self.threadpool
 
+    def add_tcp_client_callback(self, host, port, callback):
+        """Add a callback that will be invoked when we receive a connection
+        attempt to the given IP/port using `connectTCP`.
+
+        Note that the callback gets run before we return the connection to the
+        client, which means callbacks cannot block while waiting for writes.
+        """
+        self._tcp_callbacks[(host, port)] = callback
+
+    def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
+        """Fake L{IReactorTCP.connectTCP}.
+        """
+
+        conn = super().connectTCP(
+            host, port, factory, timeout=timeout, bindAddress=None
+        )
+
+        callback = self._tcp_callbacks.get((host, port))
+        if callback:
+            callback()
+
+        return conn
+
 
 class ThreadPool:
     """
@@ -486,7 +510,7 @@ class FakeTransport(object):
         try:
             self.other.dataReceived(to_write)
         except Exception as e:
-            logger.warning("Exception writing to protocol: %s", e)
+            logger.exception("Exception writing to protocol: %s", e)
             return
 
         self.buffer = self.buffer[len(to_write) :]