diff --git a/changelog.d/7874.misc b/changelog.d/7874.misc
new file mode 100644
index 0000000000..f75c8d1843
--- /dev/null
+++ b/changelog.d/7874.misc
@@ -0,0 +1 @@
+Convert the federation agent and related code to async/await.
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index c5fc746f2f..0c02648015 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -15,6 +15,7 @@
import logging
import urllib
+from typing import List
from netaddr import AddrFormatError, IPAddress
from zope.interface import implementer
@@ -236,11 +237,10 @@ class MatrixHostnameEndpoint(object):
return run_in_background(self._do_connect, protocol_factory)
- @defer.inlineCallbacks
- def _do_connect(self, protocol_factory):
+ async def _do_connect(self, protocol_factory):
first_exception = None
- server_list = yield self._resolve_server()
+ server_list = await self._resolve_server()
for server in server_list:
host = server.host
@@ -251,7 +251,7 @@ class MatrixHostnameEndpoint(object):
endpoint = HostnameEndpoint(self._reactor, host, port)
if self._tls_options:
endpoint = wrapClientTLS(self._tls_options, endpoint)
- result = yield make_deferred_yieldable(
+ result = await make_deferred_yieldable(
endpoint.connect(protocol_factory)
)
@@ -271,13 +271,9 @@ class MatrixHostnameEndpoint(object):
# to try and if that doesn't work then we'll have an exception.
raise Exception("Failed to resolve server %r" % (self._parsed_uri.netloc,))
- @defer.inlineCallbacks
- def _resolve_server(self):
+ async def _resolve_server(self) -> List[Server]:
"""Resolves the server name to a list of hosts and ports to attempt to
connect to.
-
- Returns:
- Deferred[list[Server]]
"""
if self._parsed_uri.scheme != b"matrix":
@@ -298,7 +294,7 @@ class MatrixHostnameEndpoint(object):
if port or _is_ip_literal(host):
return [Server(host, port or 8448)]
- server_list = yield self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
+ server_list = await self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
if server_list:
return server_list
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index 021b233a7d..2ede90a9b1 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -17,10 +17,10 @@
import logging
import random
import time
+from typing import List
import attr
-from twisted.internet import defer
from twisted.internet.error import ConnectError
from twisted.names import client, dns
from twisted.names.error import DNSNameError, DomainError
@@ -113,16 +113,14 @@ class SrvResolver(object):
self._cache = cache
self._get_time = get_time
- @defer.inlineCallbacks
- def resolve_service(self, service_name):
+ async def resolve_service(self, service_name: bytes) -> List[Server]:
"""Look up a SRV record
Args:
service_name (bytes): record to look up
Returns:
- Deferred[list[Server]]:
- a list of the SRV records, or an empty list if none found
+ a list of the SRV records, or an empty list if none found
"""
now = int(self._get_time())
@@ -136,7 +134,7 @@ class SrvResolver(object):
return _sort_server_list(servers)
try:
- answers, _, _ = yield make_deferred_yieldable(
+ answers, _, _ = await make_deferred_yieldable(
self._dns_client.lookupService(service_name)
)
except DNSNameError:
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 954e059e76..69945a8f98 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -67,6 +67,14 @@ def get_connection_factory():
return test_server_connection_factory
+# Once Async Mocks or lambdas are supported this can go away.
+def generate_resolve_service(result):
+ async def resolve_service(_):
+ return result
+
+ return resolve_service
+
+
class MatrixFederationAgentTests(unittest.TestCase):
def setUp(self):
self.reactor = ThreadedMemoryReactorClock()
@@ -373,7 +381,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""
Test the behaviour when the certificate on the server doesn't match the hostname
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv1"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv1/foo/bar")
@@ -456,7 +464,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
Test the behaviour when the server name has no port, no SRV, and no well-known
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
@@ -510,7 +518,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""Test the behaviour when the .well-known delegates elsewhere
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["target-server"] = "1::f"
@@ -572,7 +580,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""Test the behaviour when the server name has no port and no SRV record, but
the .well-known has a 300 redirect
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["target-server"] = "1::f"
@@ -661,7 +669,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
Test the behaviour when the server name has an *invalid* well-known (and no SRV)
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
@@ -717,7 +725,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
# the config left to the default, which will not trust it (since the
# presented cert is signed by a test CA)
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
config = default_config("test", parse=True)
@@ -764,9 +772,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""
Test the behaviour when there is a single SRV record
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: [
- Server(host=b"srvtarget", port=8443)
- ]
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+ [Server(host=b"srvtarget", port=8443)]
+ )
self.reactor.lookups["srvtarget"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
@@ -819,9 +827,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertEqual(host, "1.2.3.4")
self.assertEqual(port, 443)
- self.mock_resolver.resolve_service.side_effect = lambda _: [
- Server(host=b"srvtarget", port=8443)
- ]
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+ [Server(host=b"srvtarget", port=8443)]
+ )
self._handle_well_known_connection(
client_factory,
@@ -861,7 +869,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_idna_servername(self):
"""test the behaviour when the server name has idna chars in"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
# the resolver is always called with the IDNA hostname as a native string.
self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4"
@@ -922,9 +930,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_idna_srv_target(self):
"""test the behaviour when the target of a SRV record has idna chars"""
- self.mock_resolver.resolve_service.side_effect = lambda _: [
- Server(host=b"xn--trget-3qa.com", port=8443) # târget.com
- ]
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+ [Server(host=b"xn--trget-3qa.com", port=8443)] # târget.com
+ )
self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar")
@@ -1087,11 +1095,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_srv_fallbacks(self):
"""Test that other SRV results are tried if the first one fails.
"""
-
- self.mock_resolver.resolve_service.side_effect = lambda _: [
- Server(host=b"target.com", port=8443),
- Server(host=b"target.com", port=8444),
- ]
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+ [
+ Server(host=b"target.com", port=8443),
+ Server(host=b"target.com", port=8444),
+ ]
+ )
self.reactor.lookups["target.com"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index babc201643..fee2985d35 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -22,7 +22,7 @@ from twisted.internet.error import ConnectError
from twisted.names import dns, error
from synapse.http.federation.srv_resolver import SrvResolver
-from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
+from synapse.logging.context import LoggingContext, current_context
from tests import unittest
from tests.utils import MockClock
@@ -50,13 +50,7 @@ class SrvResolverTestCase(unittest.TestCase):
with LoggingContext("one") as ctx:
resolve_d = resolver.resolve_service(service_name)
-
- self.assertNoResult(resolve_d)
-
- # should have reset to the sentinel context
- self.assertIs(current_context(), SENTINEL_CONTEXT)
-
- result = yield resolve_d
+ result = yield defer.ensureDeferred(resolve_d)
# should have restored our context
self.assertIs(current_context(), ctx)
@@ -91,7 +85,7 @@ class SrvResolverTestCase(unittest.TestCase):
cache = {service_name: [entry]}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
- servers = yield resolver.resolve_service(service_name)
+ servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
dns_client_mock.lookupService.assert_called_once_with(service_name)
@@ -117,7 +111,7 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client=dns_client_mock, cache=cache, get_time=clock.time
)
- servers = yield resolver.resolve_service(service_name)
+ servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
self.assertFalse(dns_client_mock.lookupService.called)
@@ -136,7 +130,7 @@ class SrvResolverTestCase(unittest.TestCase):
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
with self.assertRaises(error.DNSServerError):
- yield resolver.resolve_service(service_name)
+ yield defer.ensureDeferred(resolver.resolve_service(service_name))
@defer.inlineCallbacks
def test_name_error(self):
@@ -149,7 +143,7 @@ class SrvResolverTestCase(unittest.TestCase):
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
- servers = yield resolver.resolve_service(service_name)
+ servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
self.assertEquals(len(servers), 0)
self.assertEquals(len(cache), 0)
@@ -166,8 +160,8 @@ class SrvResolverTestCase(unittest.TestCase):
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
- resolve_d = resolver.resolve_service(service_name)
- self.assertNoResult(resolve_d)
+ # Old versions of Twisted don't have an ensureDeferred in failureResultOf.
+ resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name))
# returning a single "." should make the lookup fail with a ConenctError
lookup_deferred.callback(
@@ -192,8 +186,8 @@ class SrvResolverTestCase(unittest.TestCase):
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
- resolve_d = resolver.resolve_service(service_name)
- self.assertNoResult(resolve_d)
+ # Old versions of Twisted don't have an ensureDeferred in successResultOf.
+ resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name))
lookup_deferred.callback(
(
|