diff options
Diffstat (limited to 'synapse/util/gai_resolver.py')
-rw-r--r-- | synapse/util/gai_resolver.py | 75 |
1 files changed, 61 insertions, 14 deletions
diff --git a/synapse/util/gai_resolver.py b/synapse/util/gai_resolver.py index a447ce4e55..214eb17fbc 100644 --- a/synapse/util/gai_resolver.py +++ b/synapse/util/gai_resolver.py @@ -3,23 +3,52 @@ # We copy it here as we need to instantiate `GAIResolver` manually, but it is a # private class. - from socket import ( AF_INET, AF_INET6, AF_UNSPEC, SOCK_DGRAM, SOCK_STREAM, + AddressFamily, + SocketKind, gaierror, getaddrinfo, ) +from typing import ( + TYPE_CHECKING, + Callable, + List, + NoReturn, + Optional, + Sequence, + Tuple, + Type, + Union, +) from zope.interface import implementer from twisted.internet.address import IPv4Address, IPv6Address -from twisted.internet.interfaces import IHostnameResolver, IHostResolution +from twisted.internet.interfaces import ( + IAddress, + IHostnameResolver, + IHostResolution, + IReactorThreads, + IResolutionReceiver, +) from twisted.internet.threads import deferToThreadPool +if TYPE_CHECKING: + # The types below are copied from + # https://github.com/twisted/twisted/blob/release-21.2.0-10091/src/twisted/internet/interfaces.py + # so that the type hints can match the interfaces. + from twisted.python.runtime import platform + + if platform.supportsThreads(): + from twisted.python.threadpool import ThreadPool + else: + ThreadPool = object # type: ignore[misc, assignment] + @implementer(IHostResolution) class HostResolution: @@ -27,13 +56,13 @@ class HostResolution: The in-progress resolution of a given hostname. """ - def __init__(self, name): + def __init__(self, name: str): """ Create a L{HostResolution} with the given name. """ self.name = name - def cancel(self): + def cancel(self) -> NoReturn: # IHostResolution.cancel raise NotImplementedError() @@ -62,6 +91,17 @@ _socktypeToType = { } +_GETADDRINFO_RESULT = List[ + Tuple[ + AddressFamily, + SocketKind, + int, + str, + Union[Tuple[str, int], Tuple[str, int, int, int]], + ] +] + + @implementer(IHostnameResolver) class GAIResolver: """ @@ -69,7 +109,12 @@ class GAIResolver: L{getaddrinfo} in a thread. """ - def __init__(self, reactor, getThreadPool=None, getaddrinfo=getaddrinfo): + def __init__( + self, + reactor: IReactorThreads, + getThreadPool: Optional[Callable[[], "ThreadPool"]] = None, + getaddrinfo: Callable[[str, int, int, int], _GETADDRINFO_RESULT] = getaddrinfo, + ): """ Create a L{GAIResolver}. @param reactor: the reactor to schedule result-delivery on @@ -89,14 +134,16 @@ class GAIResolver: ) self._getaddrinfo = getaddrinfo - def resolveHostName( + # The types on IHostnameResolver is incorrect in Twisted, see + # https://twistedmatrix.com/trac/ticket/10276 + def resolveHostName( # type: ignore[override] self, - resolutionReceiver, - hostName, - portNumber=0, - addressTypes=None, - transportSemantics="TCP", - ): + resolutionReceiver: IResolutionReceiver, + hostName: str, + portNumber: int = 0, + addressTypes: Optional[Sequence[Type[IAddress]]] = None, + transportSemantics: str = "TCP", + ) -> IHostResolution: """ See L{IHostnameResolver.resolveHostName} @param resolutionReceiver: see interface @@ -112,7 +159,7 @@ class GAIResolver: ] socketType = _transportToSocket[transportSemantics] - def get(): + def get() -> _GETADDRINFO_RESULT: try: return self._getaddrinfo( hostName, portNumber, addressFamily, socketType @@ -125,7 +172,7 @@ class GAIResolver: resolutionReceiver.resolutionBegan(resolution) @d.addCallback - def deliverResults(result): + def deliverResults(result: _GETADDRINFO_RESULT) -> None: for family, socktype, _proto, _cannoname, sockaddr in result: addrType = _afToType[family] resolutionReceiver.addressResolved( |