summary refs log tree commit diff
path: root/tests/http
diff options
context:
space:
mode:
Diffstat (limited to 'tests/http')
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py60
1 files changed, 24 insertions, 36 deletions
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 6a0b5fc0bd..0d17f2fe5b 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -14,8 +14,8 @@
 import base64
 import logging
 import os
-from typing import Any, Awaitable, Callable, Generator, List, Optional, cast
-from unittest.mock import Mock, patch
+from typing import Generator, List, Optional, cast
+from unittest.mock import AsyncMock, patch
 
 import treq
 from netaddr import IPSet
@@ -41,7 +41,7 @@ from twisted.web.iweb import IPolicyForHTTPS, IResponse
 from synapse.config.homeserver import HomeServerConfig
 from synapse.crypto.context_factory import FederationPolicyForHTTPS
 from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
-from synapse.http.federation.srv_resolver import Server
+from synapse.http.federation.srv_resolver import Server, SrvResolver
 from synapse.http.federation.well_known_resolver import (
     WELL_KNOWN_MAX_SIZE,
     WellKnownResolver,
@@ -68,21 +68,11 @@ from tests.utils import checked_cast, default_config
 logger = logging.getLogger(__name__)
 
 
-# Once Async Mocks or lambdas are supported this can go away.
-def generate_resolve_service(
-    result: List[Server],
-) -> Callable[[Any], Awaitable[List[Server]]]:
-    async def resolve_service(_: Any) -> List[Server]:
-        return result
-
-    return resolve_service
-
-
 class MatrixFederationAgentTests(unittest.TestCase):
     def setUp(self) -> None:
         self.reactor = ThreadedMemoryReactorClock()
 
-        self.mock_resolver = Mock()
+        self.mock_resolver = AsyncMock(spec=SrvResolver)
 
         config_dict = default_config("test", parse=False)
         config_dict["federation_custom_ca_list"] = [get_test_ca_cert_file()]
@@ -636,7 +626,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """
         self.agent = self._make_agent()
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
+        self.mock_resolver.resolve_service.return_value = []
         self.reactor.lookups["testserv1"] = "1.2.3.4"
 
         test_d = self._make_get_request(b"matrix-federation://testserv1/foo/bar")
@@ -722,7 +712,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """
         self.agent = self._make_agent()
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
+        self.mock_resolver.resolve_service.return_value = []
         self.reactor.lookups["testserv"] = "1.2.3.4"
 
         test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
@@ -776,7 +766,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """Test the behaviour when the .well-known delegates elsewhere"""
         self.agent = self._make_agent()
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
+        self.mock_resolver.resolve_service.return_value = []
         self.reactor.lookups["testserv"] = "1.2.3.4"
         self.reactor.lookups["target-server"] = "1::f"
 
@@ -840,7 +830,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """
         self.agent = self._make_agent()
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
+        self.mock_resolver.resolve_service.return_value = []
         self.reactor.lookups["testserv"] = "1.2.3.4"
         self.reactor.lookups["target-server"] = "1::f"
 
@@ -930,7 +920,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """
         self.agent = self._make_agent()
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
+        self.mock_resolver.resolve_service.return_value = []
         self.reactor.lookups["testserv"] = "1.2.3.4"
 
         test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
@@ -986,7 +976,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         # the config left to the default, which will not trust it (since the
         # presented cert is signed by a test CA)
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
+        self.mock_resolver.resolve_service.return_value = []
         self.reactor.lookups["testserv"] = "1.2.3.4"
 
         config = default_config("test", parse=True)
@@ -1037,9 +1027,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """
         self.agent = self._make_agent()
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
-            [Server(host=b"srvtarget", port=8443)]
-        )
+        self.mock_resolver.resolve_service.return_value = [
+            Server(host=b"srvtarget", port=8443)
+        ]
         self.reactor.lookups["srvtarget"] = "1.2.3.4"
 
         test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
@@ -1094,9 +1084,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.assertEqual(host, "1.2.3.4")
         self.assertEqual(port, 443)
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
-            [Server(host=b"srvtarget", port=8443)]
-        )
+        self.mock_resolver.resolve_service.return_value = [
+            Server(host=b"srvtarget", port=8443)
+        ]
 
         self._handle_well_known_connection(
             client_factory,
@@ -1137,7 +1127,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """test the behaviour when the server name has idna chars in"""
         self.agent = self._make_agent()
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
+        self.mock_resolver.resolve_service.return_value = []
 
         # the resolver is always called with the IDNA hostname as a native string.
         self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4"
@@ -1201,9 +1191,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """test the behaviour when the target of a SRV record has idna chars"""
         self.agent = self._make_agent()
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
-            [Server(host=b"xn--trget-3qa.com", port=8443)]  # târget.com
-        )
+        self.mock_resolver.resolve_service.return_value = [
+            Server(host=b"xn--trget-3qa.com", port=8443)
+        ]  # târget.com
         self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4"
 
         test_d = self._make_get_request(
@@ -1407,12 +1397,10 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """Test that other SRV results are tried if the first one fails."""
         self.agent = self._make_agent()
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
-            [
-                Server(host=b"target.com", port=8443),
-                Server(host=b"target.com", port=8444),
-            ]
-        )
+        self.mock_resolver.resolve_service.return_value = [
+            Server(host=b"target.com", port=8443),
+            Server(host=b"target.com", port=8444),
+        ]
         self.reactor.lookups["target.com"] = "1.2.3.4"
 
         test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")