diff --git a/synapse/util/gai_resolver.py b/synapse/util/gai_resolver.py
index a447ce4e55..214eb17fbc 100644
--- a/synapse/util/gai_resolver.py
+++ b/synapse/util/gai_resolver.py
@@ -3,23 +3,52 @@
# We copy it here as we need to instantiate `GAIResolver` manually, but it is a
# private class.
-
from socket import (
AF_INET,
AF_INET6,
AF_UNSPEC,
SOCK_DGRAM,
SOCK_STREAM,
+ AddressFamily,
+ SocketKind,
gaierror,
getaddrinfo,
)
+from typing import (
+ TYPE_CHECKING,
+ Callable,
+ List,
+ NoReturn,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+)
from zope.interface import implementer
from twisted.internet.address import IPv4Address, IPv6Address
-from twisted.internet.interfaces import IHostnameResolver, IHostResolution
+from twisted.internet.interfaces import (
+ IAddress,
+ IHostnameResolver,
+ IHostResolution,
+ IReactorThreads,
+ IResolutionReceiver,
+)
from twisted.internet.threads import deferToThreadPool
+if TYPE_CHECKING:
+ # The types below are copied from
+ # https://github.com/twisted/twisted/blob/release-21.2.0-10091/src/twisted/internet/interfaces.py
+ # so that the type hints can match the interfaces.
+ from twisted.python.runtime import platform
+
+ if platform.supportsThreads():
+ from twisted.python.threadpool import ThreadPool
+ else:
+ ThreadPool = object # type: ignore[misc, assignment]
+
@implementer(IHostResolution)
class HostResolution:
@@ -27,13 +56,13 @@ class HostResolution:
The in-progress resolution of a given hostname.
"""
- def __init__(self, name):
+ def __init__(self, name: str):
"""
Create a L{HostResolution} with the given name.
"""
self.name = name
- def cancel(self):
+ def cancel(self) -> NoReturn:
# IHostResolution.cancel
raise NotImplementedError()
@@ -62,6 +91,17 @@ _socktypeToType = {
}
+_GETADDRINFO_RESULT = List[
+ Tuple[
+ AddressFamily,
+ SocketKind,
+ int,
+ str,
+ Union[Tuple[str, int], Tuple[str, int, int, int]],
+ ]
+]
+
+
@implementer(IHostnameResolver)
class GAIResolver:
"""
@@ -69,7 +109,12 @@ class GAIResolver:
L{getaddrinfo} in a thread.
"""
- def __init__(self, reactor, getThreadPool=None, getaddrinfo=getaddrinfo):
+ def __init__(
+ self,
+ reactor: IReactorThreads,
+ getThreadPool: Optional[Callable[[], "ThreadPool"]] = None,
+ getaddrinfo: Callable[[str, int, int, int], _GETADDRINFO_RESULT] = getaddrinfo,
+ ):
"""
Create a L{GAIResolver}.
@param reactor: the reactor to schedule result-delivery on
@@ -89,14 +134,16 @@ class GAIResolver:
)
self._getaddrinfo = getaddrinfo
- def resolveHostName(
+ # The types on IHostnameResolver is incorrect in Twisted, see
+ # https://twistedmatrix.com/trac/ticket/10276
+ def resolveHostName( # type: ignore[override]
self,
- resolutionReceiver,
- hostName,
- portNumber=0,
- addressTypes=None,
- transportSemantics="TCP",
- ):
+ resolutionReceiver: IResolutionReceiver,
+ hostName: str,
+ portNumber: int = 0,
+ addressTypes: Optional[Sequence[Type[IAddress]]] = None,
+ transportSemantics: str = "TCP",
+ ) -> IHostResolution:
"""
See L{IHostnameResolver.resolveHostName}
@param resolutionReceiver: see interface
@@ -112,7 +159,7 @@ class GAIResolver:
]
socketType = _transportToSocket[transportSemantics]
- def get():
+ def get() -> _GETADDRINFO_RESULT:
try:
return self._getaddrinfo(
hostName, portNumber, addressFamily, socketType
@@ -125,7 +172,7 @@ class GAIResolver:
resolutionReceiver.resolutionBegan(resolution)
@d.addCallback
- def deliverResults(result):
+ def deliverResults(result: _GETADDRINFO_RESULT) -> None:
for family, socktype, _proto, _cannoname, sockaddr in result:
addrType = _afToType[family]
resolutionReceiver.addressResolved(
|