summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_auth.py18
-rw-r--r--tests/handlers/test_cas.py2
-rw-r--r--tests/handlers/test_oidc.py4
-rw-r--r--tests/handlers/test_saml.py2
-rw-r--r--tests/replication/_base.py20
-rw-r--r--tests/server.py13
6 files changed, 33 insertions, 26 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 3e05789923..d547df8a64 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -105,7 +105,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.store.get_user_by_access_token = simple_async_mock(None)
 
         request = Mock(args={})
-        request.getClientIP.return_value = "127.0.0.1"
+        request.getClientAddress.return_value.host = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = self.get_success(self.auth.get_user_by_req(request))
@@ -124,7 +124,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.store.get_user_by_access_token = simple_async_mock(None)
 
         request = Mock(args={})
-        request.getClientIP.return_value = "192.168.10.10"
+        request.getClientAddress.return_value.host = "192.168.10.10"
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = self.get_success(self.auth.get_user_by_req(request))
@@ -143,7 +143,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.store.get_user_by_access_token = simple_async_mock(None)
 
         request = Mock(args={})
-        request.getClientIP.return_value = "131.111.8.42"
+        request.getClientAddress.return_value.host = "131.111.8.42"
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         f = self.get_failure(
@@ -190,7 +190,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.store.get_user_by_access_token = simple_async_mock(None)
 
         request = Mock(args={})
-        request.getClientIP.return_value = "127.0.0.1"
+        request.getClientAddress.return_value.host = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
         request.args[b"user_id"] = [masquerading_user_id]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
@@ -209,7 +209,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.store.get_user_by_access_token = simple_async_mock(None)
 
         request = Mock(args={})
-        request.getClientIP.return_value = "127.0.0.1"
+        request.getClientAddress.return_value.host = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
         request.args[b"user_id"] = [masquerading_user_id]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
@@ -236,7 +236,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.store.get_device = simple_async_mock({"hidden": False})
 
         request = Mock(args={})
-        request.getClientIP.return_value = "127.0.0.1"
+        request.getClientAddress.return_value.host = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
         request.args[b"user_id"] = [masquerading_user_id]
         request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
@@ -268,7 +268,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.store.get_device = simple_async_mock(None)
 
         request = Mock(args={})
-        request.getClientIP.return_value = "127.0.0.1"
+        request.getClientAddress.return_value.host = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
         request.args[b"user_id"] = [masquerading_user_id]
         request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
@@ -288,7 +288,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         )
         self.store.insert_client_ip = simple_async_mock(None)
         request = Mock(args={})
-        request.getClientIP.return_value = "127.0.0.1"
+        request.getClientAddress.return_value.host = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         self.get_success(self.auth.get_user_by_req(request))
@@ -305,7 +305,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         )
         self.store.insert_client_ip = simple_async_mock(None)
         request = Mock(args={})
-        request.getClientIP.return_value = "127.0.0.1"
+        request.getClientAddress.return_value.host = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         self.get_success(self.auth.get_user_by_req(request))
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 751025c5da..2b21547d0f 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -204,7 +204,7 @@ def _mock_request():
     mock = Mock(
         spec=[
             "finish",
-            "getClientIP",
+            "getClientAddress",
             "getHeader",
             "setHeader",
             "setResponseCode",
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 9684120c70..1231aed944 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -1300,7 +1300,7 @@ def _build_callback_request(
             "getCookie",
             "cookies",
             "requestHeaders",
-            "getClientIP",
+            "getClientAddress",
             "getHeader",
         ]
     )
@@ -1310,5 +1310,5 @@ def _build_callback_request(
     request.args = {}
     request.args[b"code"] = [code.encode("utf-8")]
     request.args[b"state"] = [state.encode("utf-8")]
-    request.getClientIP.return_value = ip_address
+    request.getClientAddress.return_value.host = ip_address
     return request
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index e2f0f90ef1..a0f84e2940 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -352,7 +352,7 @@ def _mock_request():
     mock = Mock(
         spec=[
             "finish",
-            "getClientIP",
+            "getClientAddress",
             "getHeader",
             "setHeader",
             "setResponseCode",
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index a0589b6d6a..a7602b4c96 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -154,10 +154,12 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self.assertEqual(port, 8765)
 
         # Set up client side protocol
-        client_protocol = client_factory.buildProtocol(None)
+        client_address = IPv4Address("TCP", "127.0.0.1", 1234)
+        client_protocol = client_factory.buildProtocol(("127.0.0.1", 1234))
 
         # Set up the server side protocol
-        channel = self.site.buildProtocol(None)
+        server_address = IPv4Address("TCP", host, port)
+        channel = self.site.buildProtocol((host, port))
 
         # hook into the channel's request factory so that we can keep a record
         # of the requests
@@ -173,12 +175,12 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
 
         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
-            channel, self.reactor, client_protocol
+            channel, self.reactor, client_protocol, server_address, client_address
         )
         client_protocol.makeConnection(client_to_server_transport)
 
         server_to_client_transport = FakeTransport(
-            client_protocol, self.reactor, channel
+            client_protocol, self.reactor, channel, client_address, server_address
         )
         channel.makeConnection(server_to_client_transport)
 
@@ -406,19 +408,21 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         self.assertEqual(port, repl_port)
 
         # Set up client side protocol
-        client_protocol = client_factory.buildProtocol(None)
+        client_address = IPv4Address("TCP", "127.0.0.1", 1234)
+        client_protocol = client_factory.buildProtocol(("127.0.0.1", 1234))
 
         # Set up the server side protocol
-        channel = self._hs_to_site[hs].buildProtocol(None)
+        server_address = IPv4Address("TCP", host, port)
+        channel = self._hs_to_site[hs].buildProtocol((host, port))
 
         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
-            channel, self.reactor, client_protocol
+            channel, self.reactor, client_protocol, server_address, client_address
         )
         client_protocol.makeConnection(client_to_server_transport)
 
         server_to_client_transport = FakeTransport(
-            client_protocol, self.reactor, channel
+            client_protocol, self.reactor, channel, client_address, server_address
         )
         channel.makeConnection(server_to_client_transport)
 
diff --git a/tests/server.py b/tests/server.py
index 16559d2588..8f30e250c8 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -181,7 +181,7 @@ class FakeChannel:
             self.resource_usage = _self.logcontext.get_resource_usage()
 
     def getPeer(self):
-        # We give an address so that getClientIP returns a non null entry,
+        # We give an address so that getClientAddress/getClientIP returns a non null entry,
         # causing us to record the MAU
         return address.IPv4Address("TCP", self._ip, 3423)
 
@@ -562,7 +562,10 @@ class FakeTransport:
     """
 
     _peer_address: Optional[IAddress] = attr.ib(default=None)
-    """The value to be returend by getPeer"""
+    """The value to be returned by getPeer"""
+
+    _host_address: Optional[IAddress] = attr.ib(default=None)
+    """The value to be returned by getHost"""
 
     disconnecting = False
     disconnected = False
@@ -571,11 +574,11 @@ class FakeTransport:
     producer = attr.ib(default=None)
     autoflush = attr.ib(default=True)
 
-    def getPeer(self):
+    def getPeer(self) -> Optional[IAddress]:
         return self._peer_address
 
-    def getHost(self):
-        return None
+    def getHost(self) -> Optional[IAddress]:
+        return self._host_address
 
     def loseConnection(self, reason=None):
         if not self.disconnecting: