summary refs log tree commit diff
path: root/synapse/util/gai_resolver.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/gai_resolver.py')
-rw-r--r--synapse/util/gai_resolver.py75
1 files changed, 61 insertions, 14 deletions
diff --git a/synapse/util/gai_resolver.py b/synapse/util/gai_resolver.py
index a447ce4e55..214eb17fbc 100644
--- a/synapse/util/gai_resolver.py
+++ b/synapse/util/gai_resolver.py
@@ -3,23 +3,52 @@
 # We copy it here as we need to instantiate `GAIResolver` manually, but it is a
 # private class.
 
-
 from socket import (
     AF_INET,
     AF_INET6,
     AF_UNSPEC,
     SOCK_DGRAM,
     SOCK_STREAM,
+    AddressFamily,
+    SocketKind,
     gaierror,
     getaddrinfo,
 )
+from typing import (
+    TYPE_CHECKING,
+    Callable,
+    List,
+    NoReturn,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+)
 
 from zope.interface import implementer
 
 from twisted.internet.address import IPv4Address, IPv6Address
-from twisted.internet.interfaces import IHostnameResolver, IHostResolution
+from twisted.internet.interfaces import (
+    IAddress,
+    IHostnameResolver,
+    IHostResolution,
+    IReactorThreads,
+    IResolutionReceiver,
+)
 from twisted.internet.threads import deferToThreadPool
 
+if TYPE_CHECKING:
+    # The types below are copied from
+    # https://github.com/twisted/twisted/blob/release-21.2.0-10091/src/twisted/internet/interfaces.py
+    # so that the type hints can match the interfaces.
+    from twisted.python.runtime import platform
+
+    if platform.supportsThreads():
+        from twisted.python.threadpool import ThreadPool
+    else:
+        ThreadPool = object  # type: ignore[misc, assignment]
+
 
 @implementer(IHostResolution)
 class HostResolution:
@@ -27,13 +56,13 @@ class HostResolution:
     The in-progress resolution of a given hostname.
     """
 
-    def __init__(self, name):
+    def __init__(self, name: str):
         """
         Create a L{HostResolution} with the given name.
         """
         self.name = name
 
-    def cancel(self):
+    def cancel(self) -> NoReturn:
         # IHostResolution.cancel
         raise NotImplementedError()
 
@@ -62,6 +91,17 @@ _socktypeToType = {
 }
 
 
+_GETADDRINFO_RESULT = List[
+    Tuple[
+        AddressFamily,
+        SocketKind,
+        int,
+        str,
+        Union[Tuple[str, int], Tuple[str, int, int, int]],
+    ]
+]
+
+
 @implementer(IHostnameResolver)
 class GAIResolver:
     """
@@ -69,7 +109,12 @@ class GAIResolver:
     L{getaddrinfo} in a thread.
     """
 
-    def __init__(self, reactor, getThreadPool=None, getaddrinfo=getaddrinfo):
+    def __init__(
+        self,
+        reactor: IReactorThreads,
+        getThreadPool: Optional[Callable[[], "ThreadPool"]] = None,
+        getaddrinfo: Callable[[str, int, int, int], _GETADDRINFO_RESULT] = getaddrinfo,
+    ):
         """
         Create a L{GAIResolver}.
         @param reactor: the reactor to schedule result-delivery on
@@ -89,14 +134,16 @@ class GAIResolver:
         )
         self._getaddrinfo = getaddrinfo
 
-    def resolveHostName(
+    # The types on IHostnameResolver is incorrect in Twisted, see
+    # https://twistedmatrix.com/trac/ticket/10276
+    def resolveHostName(  # type: ignore[override]
         self,
-        resolutionReceiver,
-        hostName,
-        portNumber=0,
-        addressTypes=None,
-        transportSemantics="TCP",
-    ):
+        resolutionReceiver: IResolutionReceiver,
+        hostName: str,
+        portNumber: int = 0,
+        addressTypes: Optional[Sequence[Type[IAddress]]] = None,
+        transportSemantics: str = "TCP",
+    ) -> IHostResolution:
         """
         See L{IHostnameResolver.resolveHostName}
         @param resolutionReceiver: see interface
@@ -112,7 +159,7 @@ class GAIResolver:
         ]
         socketType = _transportToSocket[transportSemantics]
 
-        def get():
+        def get() -> _GETADDRINFO_RESULT:
             try:
                 return self._getaddrinfo(
                     hostName, portNumber, addressFamily, socketType
@@ -125,7 +172,7 @@ class GAIResolver:
         resolutionReceiver.resolutionBegan(resolution)
 
         @d.addCallback
-        def deliverResults(result):
+        def deliverResults(result: _GETADDRINFO_RESULT) -> None:
             for family, socktype, _proto, _cannoname, sockaddr in result:
                 addrType = _afToType[family]
                 resolutionReceiver.addressResolved(