summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2022-05-04 14:11:21 -0400
committerGitHub <noreply@github.com>2022-05-04 14:11:21 -0400
commit7fbf42499d92ec3c9a05d9f36ec5fecd1ab1f18c (patch)
tree5f3a08745f204376a211aa3ece6f92f4108a9891
parentImplement changes to MSC2285 (hidden read receipts) (#12168) (diff)
downloadsynapse-7fbf42499d92ec3c9a05d9f36ec5fecd1ab1f18c.tar.xz
Use `getClientAddress` instead of `getClientIP`. (#12599)
getClientIP was deprecated in Twisted 18.4.0, which also added
getClientAddress. The Synapse minimum version for Twisted is
currently 18.9.0, so all supported versions have the new API.
Diffstat (limited to '')
-rw-r--r--changelog.d/12599.misc1
-rw-r--r--synapse/api/auth.py4
-rw-r--r--synapse/handlers/auth.py2
-rw-r--r--synapse/handlers/identity.py2
-rw-r--r--synapse/handlers/sso.py4
-rw-r--r--synapse/http/site.py6
-rw-r--r--synapse/logging/opentracing.py2
-rw-r--r--synapse/rest/client/auth.py8
-rw-r--r--synapse/rest/client/login.py14
-rw-r--r--synapse/rest/client/register.py6
-rw-r--r--tests/api/test_auth.py18
-rw-r--r--tests/handlers/test_cas.py2
-rw-r--r--tests/handlers/test_oidc.py4
-rw-r--r--tests/handlers/test_saml.py2
-rw-r--r--tests/replication/_base.py20
-rw-r--r--tests/server.py13
16 files changed, 62 insertions, 46 deletions
diff --git a/changelog.d/12599.misc b/changelog.d/12599.misc
new file mode 100644
index 0000000000..d01278bbce
--- /dev/null
+++ b/changelog.d/12599.misc
@@ -0,0 +1 @@
+Use `getClientAddress` instead of the deprecated `getClientIP`.
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 01c32417d8..f6202ef7a5 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -187,7 +187,7 @@ class Auth:
         Once get_user_by_req has set up the opentracing span, this does the actual work.
         """
         try:
-            ip_addr = request.getClientIP()
+            ip_addr = request.getClientAddress().host
             user_agent = get_request_user_agent(request)
 
             access_token = self.get_access_token_from_request(request)
@@ -356,7 +356,7 @@ class Auth:
             return None, None, None
 
         if app_service.ip_range_whitelist:
-            ip_address = IPAddress(request.getClientIP())
+            ip_address = IPAddress(request.getClientAddress().host)
             if ip_address not in app_service.ip_range_whitelist:
                 return None, None, None
 
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 22678d486d..ad41337b28 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -551,7 +551,7 @@ class AuthHandler:
             await self.store.set_ui_auth_clientdict(sid, clientdict)
 
         user_agent = get_request_user_agent(request)
-        clientip = request.getClientIP()
+        clientip = request.getClientAddress().host
 
         await self.store.add_user_agent_ip_to_ui_auth_session(
             session.session_id, user_agent, clientip
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index c183e9c465..9bca2bc4b2 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -92,7 +92,7 @@ class IdentityHandler:
         """
 
         await self._3pid_validation_ratelimiter_ip.ratelimit(
-            None, (medium, request.getClientIP())
+            None, (medium, request.getClientAddress().host)
         )
         await self._3pid_validation_ratelimiter_address.ratelimit(
             None, (medium, address)
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index e4fe94e557..1e171f3f71 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -468,7 +468,7 @@ class SsoHandler:
                     auth_provider_id,
                     remote_user_id,
                     get_request_user_agent(request),
-                    request.getClientIP(),
+                    request.getClientAddress().host,
                 )
                 new_user = True
             elif self._sso_update_profile_information:
@@ -928,7 +928,7 @@ class SsoHandler:
             session.auth_provider_id,
             session.remote_user_id,
             get_request_user_agent(request),
-            request.getClientIP(),
+            request.getClientAddress().host,
         )
 
         logger.info(
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 40f6c04894..0b85a57d77 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -238,7 +238,7 @@ class SynapseRequest(Request):
             request_id,
             request=ContextRequest(
                 request_id=request_id,
-                ip_address=self.getClientIP(),
+                ip_address=self.getClientAddress().host,
                 site_tag=self.synapse_site.site_tag,
                 # The requester is going to be unknown at this point.
                 requester=None,
@@ -381,7 +381,7 @@ class SynapseRequest(Request):
 
         self.synapse_site.access_logger.debug(
             "%s - %s - Received request: %s %s",
-            self.getClientIP(),
+            self.getClientAddress().host,
             self.synapse_site.site_tag,
             self.get_method(),
             self.get_redacted_uri(),
@@ -429,7 +429,7 @@ class SynapseRequest(Request):
             "%s - %s - {%s}"
             " Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
             ' %sB %s "%s %s %s" "%s" [%d dbevts]',
-            self.getClientIP(),
+            self.getClientAddress().host,
             self.synapse_site.site_tag,
             requester,
             processing_time,
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index f86ee9aac7..a02b5bf6bd 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -884,7 +884,7 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
         tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
         tags.HTTP_METHOD: request.get_method(),
         tags.HTTP_URL: request.get_redacted_uri(),
-        tags.PEER_HOST_IPV6: request.getClientIP(),
+        tags.PEER_HOST_IPV6: request.getClientAddress().host,
     }
 
     request_name = request.request_metrics.name
diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py
index e0b2b80e5b..eb77337044 100644
--- a/synapse/rest/client/auth.py
+++ b/synapse/rest/client/auth.py
@@ -112,7 +112,7 @@ class AuthRestServlet(RestServlet):
 
             try:
                 await self.auth_handler.add_oob_auth(
-                    LoginType.RECAPTCHA, authdict, request.getClientIP()
+                    LoginType.RECAPTCHA, authdict, request.getClientAddress().host
                 )
             except LoginError as e:
                 # Authentication failed, let user try again
@@ -132,7 +132,7 @@ class AuthRestServlet(RestServlet):
 
             try:
                 await self.auth_handler.add_oob_auth(
-                    LoginType.TERMS, authdict, request.getClientIP()
+                    LoginType.TERMS, authdict, request.getClientAddress().host
                 )
             except LoginError as e:
                 # Authentication failed, let user try again
@@ -161,7 +161,9 @@ class AuthRestServlet(RestServlet):
 
             try:
                 await self.auth_handler.add_oob_auth(
-                    LoginType.REGISTRATION_TOKEN, authdict, request.getClientIP()
+                    LoginType.REGISTRATION_TOKEN,
+                    authdict,
+                    request.getClientAddress().host,
                 )
             except LoginError as e:
                 html = self.registration_token_template.render(
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index 71d8038448..cf4196ac0a 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -176,7 +176,7 @@ class LoginRestServlet(RestServlet):
 
                 if appservice.is_rate_limited():
                     await self._address_ratelimiter.ratelimit(
-                        None, request.getClientIP()
+                        None, request.getClientAddress().host
                     )
 
                 result = await self._do_appservice_login(
@@ -188,19 +188,25 @@ class LoginRestServlet(RestServlet):
                 self.jwt_enabled
                 and login_submission["type"] == LoginRestServlet.JWT_TYPE
             ):
-                await self._address_ratelimiter.ratelimit(None, request.getClientIP())
+                await self._address_ratelimiter.ratelimit(
+                    None, request.getClientAddress().host
+                )
                 result = await self._do_jwt_login(
                     login_submission,
                     should_issue_refresh_token=should_issue_refresh_token,
                 )
             elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
-                await self._address_ratelimiter.ratelimit(None, request.getClientIP())
+                await self._address_ratelimiter.ratelimit(
+                    None, request.getClientAddress().host
+                )
                 result = await self._do_token_login(
                     login_submission,
                     should_issue_refresh_token=should_issue_refresh_token,
                 )
             else:
-                await self._address_ratelimiter.ratelimit(None, request.getClientIP())
+                await self._address_ratelimiter.ratelimit(
+                    None, request.getClientAddress().host
+                )
                 result = await self._do_other_login(
                     login_submission,
                     should_issue_refresh_token=should_issue_refresh_token,
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 13ef6b35a0..47b6db1ebf 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -352,7 +352,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
         if self.inhibit_user_in_use_error:
             return 200, {"available": True}
 
-        ip = request.getClientIP()
+        ip = request.getClientAddress().host
         with self.ratelimiter.ratelimit(ip) as wait_deferred:
             await wait_deferred
 
@@ -394,7 +394,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
         )
 
     async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
-        await self.ratelimiter.ratelimit(None, (request.getClientIP(),))
+        await self.ratelimiter.ratelimit(None, (request.getClientAddress().host,))
 
         if not self.hs.config.registration.enable_registration:
             raise SynapseError(
@@ -441,7 +441,7 @@ class RegisterRestServlet(RestServlet):
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         body = parse_json_object_from_request(request)
 
-        client_addr = request.getClientIP()
+        client_addr = request.getClientAddress().host
 
         await self.ratelimiter.ratelimit(None, client_addr, update=False)
 
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 3e05789923..d547df8a64 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -105,7 +105,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.store.get_user_by_access_token = simple_async_mock(None)
 
         request = Mock(args={})
-        request.getClientIP.return_value = "127.0.0.1"
+        request.getClientAddress.return_value.host = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = self.get_success(self.auth.get_user_by_req(request))
@@ -124,7 +124,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.store.get_user_by_access_token = simple_async_mock(None)
 
         request = Mock(args={})
-        request.getClientIP.return_value = "192.168.10.10"
+        request.getClientAddress.return_value.host = "192.168.10.10"
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = self.get_success(self.auth.get_user_by_req(request))
@@ -143,7 +143,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.store.get_user_by_access_token = simple_async_mock(None)
 
         request = Mock(args={})
-        request.getClientIP.return_value = "131.111.8.42"
+        request.getClientAddress.return_value.host = "131.111.8.42"
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         f = self.get_failure(
@@ -190,7 +190,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.store.get_user_by_access_token = simple_async_mock(None)
 
         request = Mock(args={})
-        request.getClientIP.return_value = "127.0.0.1"
+        request.getClientAddress.return_value.host = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
         request.args[b"user_id"] = [masquerading_user_id]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
@@ -209,7 +209,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.store.get_user_by_access_token = simple_async_mock(None)
 
         request = Mock(args={})
-        request.getClientIP.return_value = "127.0.0.1"
+        request.getClientAddress.return_value.host = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
         request.args[b"user_id"] = [masquerading_user_id]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
@@ -236,7 +236,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.store.get_device = simple_async_mock({"hidden": False})
 
         request = Mock(args={})
-        request.getClientIP.return_value = "127.0.0.1"
+        request.getClientAddress.return_value.host = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
         request.args[b"user_id"] = [masquerading_user_id]
         request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
@@ -268,7 +268,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.store.get_device = simple_async_mock(None)
 
         request = Mock(args={})
-        request.getClientIP.return_value = "127.0.0.1"
+        request.getClientAddress.return_value.host = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
         request.args[b"user_id"] = [masquerading_user_id]
         request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
@@ -288,7 +288,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         )
         self.store.insert_client_ip = simple_async_mock(None)
         request = Mock(args={})
-        request.getClientIP.return_value = "127.0.0.1"
+        request.getClientAddress.return_value.host = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         self.get_success(self.auth.get_user_by_req(request))
@@ -305,7 +305,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         )
         self.store.insert_client_ip = simple_async_mock(None)
         request = Mock(args={})
-        request.getClientIP.return_value = "127.0.0.1"
+        request.getClientAddress.return_value.host = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         self.get_success(self.auth.get_user_by_req(request))
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 751025c5da..2b21547d0f 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -204,7 +204,7 @@ def _mock_request():
     mock = Mock(
         spec=[
             "finish",
-            "getClientIP",
+            "getClientAddress",
             "getHeader",
             "setHeader",
             "setResponseCode",
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 9684120c70..1231aed944 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -1300,7 +1300,7 @@ def _build_callback_request(
             "getCookie",
             "cookies",
             "requestHeaders",
-            "getClientIP",
+            "getClientAddress",
             "getHeader",
         ]
     )
@@ -1310,5 +1310,5 @@ def _build_callback_request(
     request.args = {}
     request.args[b"code"] = [code.encode("utf-8")]
     request.args[b"state"] = [state.encode("utf-8")]
-    request.getClientIP.return_value = ip_address
+    request.getClientAddress.return_value.host = ip_address
     return request
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index e2f0f90ef1..a0f84e2940 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -352,7 +352,7 @@ def _mock_request():
     mock = Mock(
         spec=[
             "finish",
-            "getClientIP",
+            "getClientAddress",
             "getHeader",
             "setHeader",
             "setResponseCode",
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index a0589b6d6a..a7602b4c96 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -154,10 +154,12 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self.assertEqual(port, 8765)
 
         # Set up client side protocol
-        client_protocol = client_factory.buildProtocol(None)
+        client_address = IPv4Address("TCP", "127.0.0.1", 1234)
+        client_protocol = client_factory.buildProtocol(("127.0.0.1", 1234))
 
         # Set up the server side protocol
-        channel = self.site.buildProtocol(None)
+        server_address = IPv4Address("TCP", host, port)
+        channel = self.site.buildProtocol((host, port))
 
         # hook into the channel's request factory so that we can keep a record
         # of the requests
@@ -173,12 +175,12 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
 
         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
-            channel, self.reactor, client_protocol
+            channel, self.reactor, client_protocol, server_address, client_address
         )
         client_protocol.makeConnection(client_to_server_transport)
 
         server_to_client_transport = FakeTransport(
-            client_protocol, self.reactor, channel
+            client_protocol, self.reactor, channel, client_address, server_address
         )
         channel.makeConnection(server_to_client_transport)
 
@@ -406,19 +408,21 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         self.assertEqual(port, repl_port)
 
         # Set up client side protocol
-        client_protocol = client_factory.buildProtocol(None)
+        client_address = IPv4Address("TCP", "127.0.0.1", 1234)
+        client_protocol = client_factory.buildProtocol(("127.0.0.1", 1234))
 
         # Set up the server side protocol
-        channel = self._hs_to_site[hs].buildProtocol(None)
+        server_address = IPv4Address("TCP", host, port)
+        channel = self._hs_to_site[hs].buildProtocol((host, port))
 
         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
-            channel, self.reactor, client_protocol
+            channel, self.reactor, client_protocol, server_address, client_address
         )
         client_protocol.makeConnection(client_to_server_transport)
 
         server_to_client_transport = FakeTransport(
-            client_protocol, self.reactor, channel
+            client_protocol, self.reactor, channel, client_address, server_address
         )
         channel.makeConnection(server_to_client_transport)
 
diff --git a/tests/server.py b/tests/server.py
index 16559d2588..8f30e250c8 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -181,7 +181,7 @@ class FakeChannel:
             self.resource_usage = _self.logcontext.get_resource_usage()
 
     def getPeer(self):
-        # We give an address so that getClientIP returns a non null entry,
+        # We give an address so that getClientAddress/getClientIP returns a non null entry,
         # causing us to record the MAU
         return address.IPv4Address("TCP", self._ip, 3423)
 
@@ -562,7 +562,10 @@ class FakeTransport:
     """
 
     _peer_address: Optional[IAddress] = attr.ib(default=None)
-    """The value to be returend by getPeer"""
+    """The value to be returned by getPeer"""
+
+    _host_address: Optional[IAddress] = attr.ib(default=None)
+    """The value to be returned by getHost"""
 
     disconnecting = False
     disconnected = False
@@ -571,11 +574,11 @@ class FakeTransport:
     producer = attr.ib(default=None)
     autoflush = attr.ib(default=True)
 
-    def getPeer(self):
+    def getPeer(self) -> Optional[IAddress]:
         return self._peer_address
 
-    def getHost(self):
-        return None
+    def getHost(self) -> Optional[IAddress]:
+        return self._host_address
 
     def loseConnection(self, reason=None):
         if not self.disconnecting: