summary refs log tree commit diff
path: root/tests/replication/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/replication/_base.py')
-rw-r--r--tests/replication/_base.py20
1 files changed, 12 insertions, 8 deletions
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index a0589b6d6a..a7602b4c96 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -154,10 +154,12 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self.assertEqual(port, 8765)
 
         # Set up client side protocol
-        client_protocol = client_factory.buildProtocol(None)
+        client_address = IPv4Address("TCP", "127.0.0.1", 1234)
+        client_protocol = client_factory.buildProtocol(("127.0.0.1", 1234))
 
         # Set up the server side protocol
-        channel = self.site.buildProtocol(None)
+        server_address = IPv4Address("TCP", host, port)
+        channel = self.site.buildProtocol((host, port))
 
         # hook into the channel's request factory so that we can keep a record
         # of the requests
@@ -173,12 +175,12 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
 
         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
-            channel, self.reactor, client_protocol
+            channel, self.reactor, client_protocol, server_address, client_address
         )
         client_protocol.makeConnection(client_to_server_transport)
 
         server_to_client_transport = FakeTransport(
-            client_protocol, self.reactor, channel
+            client_protocol, self.reactor, channel, client_address, server_address
         )
         channel.makeConnection(server_to_client_transport)
 
@@ -406,19 +408,21 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         self.assertEqual(port, repl_port)
 
         # Set up client side protocol
-        client_protocol = client_factory.buildProtocol(None)
+        client_address = IPv4Address("TCP", "127.0.0.1", 1234)
+        client_protocol = client_factory.buildProtocol(("127.0.0.1", 1234))
 
         # Set up the server side protocol
-        channel = self._hs_to_site[hs].buildProtocol(None)
+        server_address = IPv4Address("TCP", host, port)
+        channel = self._hs_to_site[hs].buildProtocol((host, port))
 
         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
-            channel, self.reactor, client_protocol
+            channel, self.reactor, client_protocol, server_address, client_address
         )
         client_protocol.makeConnection(client_to_server_transport)
 
         server_to_client_transport = FakeTransport(
-            client_protocol, self.reactor, channel
+            client_protocol, self.reactor, channel, client_address, server_address
         )
         channel.makeConnection(server_to_client_transport)