summary refs log tree commit diff
path: root/tests/http/federation/test_srv_resolver.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/http/federation/test_srv_resolver.py')
-rw-r--r--tests/http/federation/test_srv_resolver.py60
1 files changed, 36 insertions, 24 deletions
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index 77ce8432ac..7748f56ee6 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+from typing import Dict, Generator, List, Tuple, cast
 from unittest.mock import Mock
 
 from twisted.internet import defer
@@ -20,7 +20,7 @@ from twisted.internet.defer import Deferred
 from twisted.internet.error import ConnectError
 from twisted.names import dns, error
 
-from synapse.http.federation.srv_resolver import SrvResolver
+from synapse.http.federation.srv_resolver import Server, SrvResolver
 from synapse.logging.context import LoggingContext, current_context
 
 from tests import unittest
@@ -28,7 +28,7 @@ from tests.utils import MockClock
 
 
 class SrvResolverTestCase(unittest.TestCase):
-    def test_resolve(self):
+    def test_resolve(self) -> None:
         dns_client_mock = Mock()
 
         service_name = b"test_service.example.com"
@@ -38,18 +38,19 @@ class SrvResolverTestCase(unittest.TestCase):
             type=dns.SRV, payload=dns.Record_SRV(target=host_name)
         )
 
-        result_deferred = Deferred()
+        result_deferred: "Deferred[Tuple[List[dns.RRHeader], None, None]]" = Deferred()
         dns_client_mock.lookupService.return_value = result_deferred
 
-        cache = {}
+        cache: Dict[bytes, List[Server]] = {}
         resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
 
         @defer.inlineCallbacks
-        def do_lookup():
+        def do_lookup() -> Generator["Deferred[object]", object, List[Server]]:
 
             with LoggingContext("one") as ctx:
                 resolve_d = resolver.resolve_service(service_name)
-                result = yield defer.ensureDeferred(resolve_d)
+                result: List[Server]
+                result = yield defer.ensureDeferred(resolve_d)  # type: ignore[assignment]
 
                 # should have restored our context
                 self.assertIs(current_context(), ctx)
@@ -70,7 +71,9 @@ class SrvResolverTestCase(unittest.TestCase):
         self.assertEqual(servers[0].host, host_name)
 
     @defer.inlineCallbacks
-    def test_from_cache_expired_and_dns_fail(self):
+    def test_from_cache_expired_and_dns_fail(
+        self,
+    ) -> Generator["Deferred[object]", object, None]:
         dns_client_mock = Mock()
         dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
 
@@ -81,10 +84,13 @@ class SrvResolverTestCase(unittest.TestCase):
         entry.priority = 0
         entry.weight = 0
 
-        cache = {service_name: [entry]}
+        cache = {service_name: [cast(Server, entry)]}
         resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
 
-        servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
+        servers: List[Server]
+        servers = yield defer.ensureDeferred(
+            resolver.resolve_service(service_name)
+        )  # type: ignore[assignment]
 
         dns_client_mock.lookupService.assert_called_once_with(service_name)
 
@@ -92,7 +98,7 @@ class SrvResolverTestCase(unittest.TestCase):
         self.assertEqual(servers, cache[service_name])
 
     @defer.inlineCallbacks
-    def test_from_cache(self):
+    def test_from_cache(self) -> Generator["Deferred[object]", object, None]:
         clock = MockClock()
 
         dns_client_mock = Mock(spec_set=["lookupService"])
@@ -105,12 +111,15 @@ class SrvResolverTestCase(unittest.TestCase):
         entry.priority = 0
         entry.weight = 0
 
-        cache = {service_name: [entry]}
+        cache = {service_name: [cast(Server, entry)]}
         resolver = SrvResolver(
             dns_client=dns_client_mock, cache=cache, get_time=clock.time
         )
 
-        servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
+        servers: List[Server]
+        servers = yield defer.ensureDeferred(
+            resolver.resolve_service(service_name)
+        )  # type: ignore[assignment]
 
         self.assertFalse(dns_client_mock.lookupService.called)
 
@@ -118,45 +127,48 @@ class SrvResolverTestCase(unittest.TestCase):
         self.assertEqual(servers, cache[service_name])
 
     @defer.inlineCallbacks
-    def test_empty_cache(self):
+    def test_empty_cache(self) -> Generator["Deferred[object]", object, None]:
         dns_client_mock = Mock()
 
         dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
 
         service_name = b"test_service.example.com"
 
-        cache = {}
+        cache: Dict[bytes, List[Server]] = {}
         resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
 
         with self.assertRaises(error.DNSServerError):
             yield defer.ensureDeferred(resolver.resolve_service(service_name))
 
     @defer.inlineCallbacks
-    def test_name_error(self):
+    def test_name_error(self) -> Generator["Deferred[object]", object, None]:
         dns_client_mock = Mock()
 
         dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError())
 
         service_name = b"test_service.example.com"
 
-        cache = {}
+        cache: Dict[bytes, List[Server]] = {}
         resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
 
-        servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
+        servers: List[Server]
+        servers = yield defer.ensureDeferred(
+            resolver.resolve_service(service_name)
+        )  # type: ignore[assignment]
 
         self.assertEqual(len(servers), 0)
         self.assertEqual(len(cache), 0)
 
-    def test_disabled_service(self):
+    def test_disabled_service(self) -> None:
         """
         test the behaviour when there is a single record which is ".".
         """
         service_name = b"test_service.example.com"
 
-        lookup_deferred = Deferred()
+        lookup_deferred: "Deferred[Tuple[List[dns.RRHeader], None, None]]" = Deferred()
         dns_client_mock = Mock()
         dns_client_mock.lookupService.return_value = lookup_deferred
-        cache = {}
+        cache: Dict[bytes, List[Server]] = {}
         resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
 
         # Old versions of Twisted don't have an ensureDeferred in failureResultOf.
@@ -173,16 +185,16 @@ class SrvResolverTestCase(unittest.TestCase):
 
         self.failureResultOf(resolve_d, ConnectError)
 
-    def test_non_srv_answer(self):
+    def test_non_srv_answer(self) -> None:
         """
         test the behaviour when the dns server gives us a spurious non-SRV response
         """
         service_name = b"test_service.example.com"
 
-        lookup_deferred = Deferred()
+        lookup_deferred: "Deferred[Tuple[List[dns.RRHeader], None, None]]" = Deferred()
         dns_client_mock = Mock()
         dns_client_mock.lookupService.return_value = lookup_deferred
-        cache = {}
+        cache: Dict[bytes, List[Server]] = {}
         resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
 
         # Old versions of Twisted don't have an ensureDeferred in successResultOf.