summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/app/_base.py13
-rw-r--r--synapse/util/gai_resolver.py136
2 files changed, 148 insertions, 1 deletions
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 2ca2e051e4..03627cdcba 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -31,6 +31,7 @@ import twisted
 from twisted.internet import defer, error, reactor
 from twisted.logger import LoggingFile, LogLevel
 from twisted.protocols.tls import TLSMemoryBIOFactory
+from twisted.python.threadpool import ThreadPool
 
 import synapse
 from synapse.api.constants import MAX_PDU_SIZE
@@ -48,6 +49,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
 from synapse.metrics.jemalloc import setup_jemalloc_stats
 from synapse.util.caches.lrucache import setup_expire_lru_cache_entries
 from synapse.util.daemonize import daemonize_process
+from synapse.util.gai_resolver import GAIResolver
 from synapse.util.rlimit import change_resource_limit
 from synapse.util.versionstring import get_version_string
 
@@ -338,9 +340,18 @@ async def start(hs: "HomeServer"):
     Args:
         hs: homeserver instance
     """
+    reactor = hs.get_reactor()
+
+    # We want to use a separate thread pool for the resolver so that large
+    # numbers of DNS requests don't starve out other users of the threadpool.
+    resolver_threadpool = ThreadPool(name="gai_resolver")
+    resolver_threadpool.start()
+    reactor.installNameResolver(
+        GAIResolver(reactor, getThreadPool=lambda: resolver_threadpool)
+    )
+
     # Set up the SIGHUP machinery.
     if hasattr(signal, "SIGHUP"):
-        reactor = hs.get_reactor()
 
         @wrap_as_background_process("sighup")
         def handle_sighup(*args, **kwargs):
diff --git a/synapse/util/gai_resolver.py b/synapse/util/gai_resolver.py
new file mode 100644
index 0000000000..a447ce4e55
--- /dev/null
+++ b/synapse/util/gai_resolver.py
@@ -0,0 +1,136 @@
+# This is a direct lift from
+# https://github.com/twisted/twisted/blob/release-21.2.0-10091/src/twisted/internet/_resolver.py.
+# We copy it here as we need to instantiate `GAIResolver` manually, but it is a
+# private class.
+
+
+from socket import (
+    AF_INET,
+    AF_INET6,
+    AF_UNSPEC,
+    SOCK_DGRAM,
+    SOCK_STREAM,
+    gaierror,
+    getaddrinfo,
+)
+
+from zope.interface import implementer
+
+from twisted.internet.address import IPv4Address, IPv6Address
+from twisted.internet.interfaces import IHostnameResolver, IHostResolution
+from twisted.internet.threads import deferToThreadPool
+
+
+@implementer(IHostResolution)
+class HostResolution:
+    """
+    The in-progress resolution of a given hostname.
+    """
+
+    def __init__(self, name):
+        """
+        Create a L{HostResolution} with the given name.
+        """
+        self.name = name
+
+    def cancel(self):
+        # IHostResolution.cancel
+        raise NotImplementedError()
+
+
+_any = frozenset([IPv4Address, IPv6Address])
+
+_typesToAF = {
+    frozenset([IPv4Address]): AF_INET,
+    frozenset([IPv6Address]): AF_INET6,
+    _any: AF_UNSPEC,
+}
+
+_afToType = {
+    AF_INET: IPv4Address,
+    AF_INET6: IPv6Address,
+}
+
+_transportToSocket = {
+    "TCP": SOCK_STREAM,
+    "UDP": SOCK_DGRAM,
+}
+
+_socktypeToType = {
+    SOCK_STREAM: "TCP",
+    SOCK_DGRAM: "UDP",
+}
+
+
+@implementer(IHostnameResolver)
+class GAIResolver:
+    """
+    L{IHostnameResolver} implementation that resolves hostnames by calling
+    L{getaddrinfo} in a thread.
+    """
+
+    def __init__(self, reactor, getThreadPool=None, getaddrinfo=getaddrinfo):
+        """
+        Create a L{GAIResolver}.
+        @param reactor: the reactor to schedule result-delivery on
+        @type reactor: L{IReactorThreads}
+        @param getThreadPool: a function to retrieve the thread pool to use for
+            scheduling name resolutions.  If not supplied, the use the given
+            C{reactor}'s thread pool.
+        @type getThreadPool: 0-argument callable returning a
+            L{twisted.python.threadpool.ThreadPool}
+        @param getaddrinfo: a reference to the L{getaddrinfo} to use - mainly
+            parameterized for testing.
+        @type getaddrinfo: callable with the same signature as L{getaddrinfo}
+        """
+        self._reactor = reactor
+        self._getThreadPool = (
+            reactor.getThreadPool if getThreadPool is None else getThreadPool
+        )
+        self._getaddrinfo = getaddrinfo
+
+    def resolveHostName(
+        self,
+        resolutionReceiver,
+        hostName,
+        portNumber=0,
+        addressTypes=None,
+        transportSemantics="TCP",
+    ):
+        """
+        See L{IHostnameResolver.resolveHostName}
+        @param resolutionReceiver: see interface
+        @param hostName: see interface
+        @param portNumber: see interface
+        @param addressTypes: see interface
+        @param transportSemantics: see interface
+        @return: see interface
+        """
+        pool = self._getThreadPool()
+        addressFamily = _typesToAF[
+            _any if addressTypes is None else frozenset(addressTypes)
+        ]
+        socketType = _transportToSocket[transportSemantics]
+
+        def get():
+            try:
+                return self._getaddrinfo(
+                    hostName, portNumber, addressFamily, socketType
+                )
+            except gaierror:
+                return []
+
+        d = deferToThreadPool(self._reactor, pool, get)
+        resolution = HostResolution(hostName)
+        resolutionReceiver.resolutionBegan(resolution)
+
+        @d.addCallback
+        def deliverResults(result):
+            for family, socktype, _proto, _cannoname, sockaddr in result:
+                addrType = _afToType[family]
+                resolutionReceiver.addressResolved(
+                    addrType(_socktypeToType.get(socktype, "TCP"), *sockaddr)
+                )
+            resolutionReceiver.resolutionComplete()
+
+        return resolution