diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 5df80ea8e7..96efc5f3e3 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -22,11 +22,11 @@ from typing import (
Any,
Awaitable,
Callable,
+ Collection,
Dict,
Generic,
Hashable,
Iterable,
- List,
Optional,
Set,
TypeVar,
@@ -76,12 +76,17 @@ class ObservableDeferred(Generic[_T]):
def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False):
object.__setattr__(self, "_deferred", deferred)
object.__setattr__(self, "_result", None)
- object.__setattr__(self, "_observers", set())
+ object.__setattr__(self, "_observers", [])
def callback(r):
object.__setattr__(self, "_result", (True, r))
- while self._observers:
- observer = self._observers.pop()
+
+ # once we have set _result, no more entries will be added to _observers,
+ # so it's safe to replace it with the empty tuple.
+ observers = self._observers
+ object.__setattr__(self, "_observers", ())
+
+ for observer in observers:
try:
observer.callback(r)
except Exception as e:
@@ -95,12 +100,16 @@ class ObservableDeferred(Generic[_T]):
def errback(f):
object.__setattr__(self, "_result", (False, f))
- while self._observers:
+
+ # once we have set _result, no more entries will be added to _observers,
+ # so it's safe to replace it with the empty tuple.
+ observers = self._observers
+ object.__setattr__(self, "_observers", ())
+
+ for observer in observers:
# This is a little bit of magic to correctly propagate stack
# traces when we `await` on one of the observer deferreds.
f.value.__failure__ = f
-
- observer = self._observers.pop()
try:
observer.errback(f)
except Exception as e:
@@ -127,20 +136,13 @@ class ObservableDeferred(Generic[_T]):
"""
if not self._result:
d: "defer.Deferred[_T]" = defer.Deferred()
-
- def remove(r):
- self._observers.discard(d)
- return r
-
- d.addBoth(remove)
-
- self._observers.add(d)
+ self._observers.append(d)
return d
else:
success, res = self._result
return defer.succeed(res) if success else defer.fail(res)
- def observers(self) -> "List[defer.Deferred[_T]]":
+ def observers(self) -> "Collection[defer.Deferred[_T]]":
return self._observers
def has_called(self) -> bool:
diff --git a/synapse/util/gai_resolver.py b/synapse/util/gai_resolver.py
new file mode 100644
index 0000000000..a447ce4e55
--- /dev/null
+++ b/synapse/util/gai_resolver.py
@@ -0,0 +1,136 @@
+# This is a direct lift from
+# https://github.com/twisted/twisted/blob/release-21.2.0-10091/src/twisted/internet/_resolver.py.
+# We copy it here as we need to instantiate `GAIResolver` manually, but it is a
+# private class.
+
+
+from socket import (
+ AF_INET,
+ AF_INET6,
+ AF_UNSPEC,
+ SOCK_DGRAM,
+ SOCK_STREAM,
+ gaierror,
+ getaddrinfo,
+)
+
+from zope.interface import implementer
+
+from twisted.internet.address import IPv4Address, IPv6Address
+from twisted.internet.interfaces import IHostnameResolver, IHostResolution
+from twisted.internet.threads import deferToThreadPool
+
+
+@implementer(IHostResolution)
+class HostResolution:
+ """
+ The in-progress resolution of a given hostname.
+ """
+
+ def __init__(self, name):
+ """
+ Create a L{HostResolution} with the given name.
+ """
+ self.name = name
+
+ def cancel(self):
+ # IHostResolution.cancel
+ raise NotImplementedError()
+
+
+_any = frozenset([IPv4Address, IPv6Address])
+
+_typesToAF = {
+ frozenset([IPv4Address]): AF_INET,
+ frozenset([IPv6Address]): AF_INET6,
+ _any: AF_UNSPEC,
+}
+
+_afToType = {
+ AF_INET: IPv4Address,
+ AF_INET6: IPv6Address,
+}
+
+_transportToSocket = {
+ "TCP": SOCK_STREAM,
+ "UDP": SOCK_DGRAM,
+}
+
+_socktypeToType = {
+ SOCK_STREAM: "TCP",
+ SOCK_DGRAM: "UDP",
+}
+
+
+@implementer(IHostnameResolver)
+class GAIResolver:
+ """
+ L{IHostnameResolver} implementation that resolves hostnames by calling
+ L{getaddrinfo} in a thread.
+ """
+
+ def __init__(self, reactor, getThreadPool=None, getaddrinfo=getaddrinfo):
+ """
+ Create a L{GAIResolver}.
+ @param reactor: the reactor to schedule result-delivery on
+ @type reactor: L{IReactorThreads}
+ @param getThreadPool: a function to retrieve the thread pool to use for
+ scheduling name resolutions. If not supplied, the use the given
+ C{reactor}'s thread pool.
+ @type getThreadPool: 0-argument callable returning a
+ L{twisted.python.threadpool.ThreadPool}
+ @param getaddrinfo: a reference to the L{getaddrinfo} to use - mainly
+ parameterized for testing.
+ @type getaddrinfo: callable with the same signature as L{getaddrinfo}
+ """
+ self._reactor = reactor
+ self._getThreadPool = (
+ reactor.getThreadPool if getThreadPool is None else getThreadPool
+ )
+ self._getaddrinfo = getaddrinfo
+
+ def resolveHostName(
+ self,
+ resolutionReceiver,
+ hostName,
+ portNumber=0,
+ addressTypes=None,
+ transportSemantics="TCP",
+ ):
+ """
+ See L{IHostnameResolver.resolveHostName}
+ @param resolutionReceiver: see interface
+ @param hostName: see interface
+ @param portNumber: see interface
+ @param addressTypes: see interface
+ @param transportSemantics: see interface
+ @return: see interface
+ """
+ pool = self._getThreadPool()
+ addressFamily = _typesToAF[
+ _any if addressTypes is None else frozenset(addressTypes)
+ ]
+ socketType = _transportToSocket[transportSemantics]
+
+ def get():
+ try:
+ return self._getaddrinfo(
+ hostName, portNumber, addressFamily, socketType
+ )
+ except gaierror:
+ return []
+
+ d = deferToThreadPool(self._reactor, pool, get)
+ resolution = HostResolution(hostName)
+ resolutionReceiver.resolutionBegan(resolution)
+
+ @d.addCallback
+ def deliverResults(result):
+ for family, socktype, _proto, _cannoname, sockaddr in result:
+ addrType = _afToType[family]
+ resolutionReceiver.addressResolved(
+ addrType(_socktypeToType.get(socktype, "TCP"), *sockaddr)
+ )
+ resolutionReceiver.resolutionComplete()
+
+ return resolution
|