summary refs log tree commit diff
path: root/tests/http/federation
diff options
context:
space:
mode:
Diffstat (limited to 'tests/http/federation')
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py51
-rw-r--r--tests/http/federation/test_srv_resolver.py26
2 files changed, 40 insertions, 37 deletions
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 954e059e76..69945a8f98 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -67,6 +67,14 @@ def get_connection_factory():
     return test_server_connection_factory
 
 
+# Once Async Mocks or lambdas are supported this can go away.
+def generate_resolve_service(result):
+    async def resolve_service(_):
+        return result
+
+    return resolve_service
+
+
 class MatrixFederationAgentTests(unittest.TestCase):
     def setUp(self):
         self.reactor = ThreadedMemoryReactorClock()
@@ -373,7 +381,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """
         Test the behaviour when the certificate on the server doesn't match the hostname
         """
-        self.mock_resolver.resolve_service.side_effect = lambda _: []
+        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
         self.reactor.lookups["testserv1"] = "1.2.3.4"
 
         test_d = self._make_get_request(b"matrix://testserv1/foo/bar")
@@ -456,7 +464,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         Test the behaviour when the server name has no port, no SRV, and no well-known
         """
 
-        self.mock_resolver.resolve_service.side_effect = lambda _: []
+        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
         self.reactor.lookups["testserv"] = "1.2.3.4"
 
         test_d = self._make_get_request(b"matrix://testserv/foo/bar")
@@ -510,7 +518,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """Test the behaviour when the .well-known delegates elsewhere
         """
 
-        self.mock_resolver.resolve_service.side_effect = lambda _: []
+        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
         self.reactor.lookups["testserv"] = "1.2.3.4"
         self.reactor.lookups["target-server"] = "1::f"
 
@@ -572,7 +580,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """Test the behaviour when the server name has no port and no SRV record, but
         the .well-known has a 300 redirect
         """
-        self.mock_resolver.resolve_service.side_effect = lambda _: []
+        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
         self.reactor.lookups["testserv"] = "1.2.3.4"
         self.reactor.lookups["target-server"] = "1::f"
 
@@ -661,7 +669,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         Test the behaviour when the server name has an *invalid* well-known (and no SRV)
         """
 
-        self.mock_resolver.resolve_service.side_effect = lambda _: []
+        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
         self.reactor.lookups["testserv"] = "1.2.3.4"
 
         test_d = self._make_get_request(b"matrix://testserv/foo/bar")
@@ -717,7 +725,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         # the config left to the default, which will not trust it (since the
         # presented cert is signed by a test CA)
 
-        self.mock_resolver.resolve_service.side_effect = lambda _: []
+        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
         self.reactor.lookups["testserv"] = "1.2.3.4"
 
         config = default_config("test", parse=True)
@@ -764,9 +772,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """
         Test the behaviour when there is a single SRV record
         """
-        self.mock_resolver.resolve_service.side_effect = lambda _: [
-            Server(host=b"srvtarget", port=8443)
-        ]
+        self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+            [Server(host=b"srvtarget", port=8443)]
+        )
         self.reactor.lookups["srvtarget"] = "1.2.3.4"
 
         test_d = self._make_get_request(b"matrix://testserv/foo/bar")
@@ -819,9 +827,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.assertEqual(host, "1.2.3.4")
         self.assertEqual(port, 443)
 
-        self.mock_resolver.resolve_service.side_effect = lambda _: [
-            Server(host=b"srvtarget", port=8443)
-        ]
+        self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+            [Server(host=b"srvtarget", port=8443)]
+        )
 
         self._handle_well_known_connection(
             client_factory,
@@ -861,7 +869,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
     def test_idna_servername(self):
         """test the behaviour when the server name has idna chars in"""
 
-        self.mock_resolver.resolve_service.side_effect = lambda _: []
+        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
 
         # the resolver is always called with the IDNA hostname as a native string.
         self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4"
@@ -922,9 +930,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
     def test_idna_srv_target(self):
         """test the behaviour when the target of a SRV record has idna chars"""
 
-        self.mock_resolver.resolve_service.side_effect = lambda _: [
-            Server(host=b"xn--trget-3qa.com", port=8443)  # târget.com
-        ]
+        self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+            [Server(host=b"xn--trget-3qa.com", port=8443)]  # târget.com
+        )
         self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4"
 
         test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar")
@@ -1087,11 +1095,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
     def test_srv_fallbacks(self):
         """Test that other SRV results are tried if the first one fails.
         """
-
-        self.mock_resolver.resolve_service.side_effect = lambda _: [
-            Server(host=b"target.com", port=8443),
-            Server(host=b"target.com", port=8444),
-        ]
+        self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+            [
+                Server(host=b"target.com", port=8443),
+                Server(host=b"target.com", port=8444),
+            ]
+        )
         self.reactor.lookups["target.com"] = "1.2.3.4"
 
         test_d = self._make_get_request(b"matrix://testserv/foo/bar")
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index babc201643..fee2985d35 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -22,7 +22,7 @@ from twisted.internet.error import ConnectError
 from twisted.names import dns, error
 
 from synapse.http.federation.srv_resolver import SrvResolver
-from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
+from synapse.logging.context import LoggingContext, current_context
 
 from tests import unittest
 from tests.utils import MockClock
@@ -50,13 +50,7 @@ class SrvResolverTestCase(unittest.TestCase):
 
             with LoggingContext("one") as ctx:
                 resolve_d = resolver.resolve_service(service_name)
-
-                self.assertNoResult(resolve_d)
-
-                # should have reset to the sentinel context
-                self.assertIs(current_context(), SENTINEL_CONTEXT)
-
-                result = yield resolve_d
+                result = yield defer.ensureDeferred(resolve_d)
 
                 # should have restored our context
                 self.assertIs(current_context(), ctx)
@@ -91,7 +85,7 @@ class SrvResolverTestCase(unittest.TestCase):
         cache = {service_name: [entry]}
         resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
 
-        servers = yield resolver.resolve_service(service_name)
+        servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
 
         dns_client_mock.lookupService.assert_called_once_with(service_name)
 
@@ -117,7 +111,7 @@ class SrvResolverTestCase(unittest.TestCase):
             dns_client=dns_client_mock, cache=cache, get_time=clock.time
         )
 
-        servers = yield resolver.resolve_service(service_name)
+        servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
 
         self.assertFalse(dns_client_mock.lookupService.called)
 
@@ -136,7 +130,7 @@ class SrvResolverTestCase(unittest.TestCase):
         resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
 
         with self.assertRaises(error.DNSServerError):
-            yield resolver.resolve_service(service_name)
+            yield defer.ensureDeferred(resolver.resolve_service(service_name))
 
     @defer.inlineCallbacks
     def test_name_error(self):
@@ -149,7 +143,7 @@ class SrvResolverTestCase(unittest.TestCase):
         cache = {}
         resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
 
-        servers = yield resolver.resolve_service(service_name)
+        servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
 
         self.assertEquals(len(servers), 0)
         self.assertEquals(len(cache), 0)
@@ -166,8 +160,8 @@ class SrvResolverTestCase(unittest.TestCase):
         cache = {}
         resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
 
-        resolve_d = resolver.resolve_service(service_name)
-        self.assertNoResult(resolve_d)
+        # Old versions of Twisted don't have an ensureDeferred in failureResultOf.
+        resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name))
 
         # returning a single "." should make the lookup fail with a ConenctError
         lookup_deferred.callback(
@@ -192,8 +186,8 @@ class SrvResolverTestCase(unittest.TestCase):
         cache = {}
         resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
 
-        resolve_d = resolver.resolve_service(service_name)
-        self.assertNoResult(resolve_d)
+        # Old versions of Twisted don't have an ensureDeferred in successResultOf.
+        resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name))
 
         lookup_deferred.callback(
             (