summary refs log tree commit diff
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2019-01-22 17:35:09 +0000
committerRichard van der Hoff <richard@matrix.org>2019-01-22 20:35:12 +0000
commit53a327b4d5bdcab36651aa86ddf1815ff86e5db2 (patch)
treee10bb401e66216757ed73fe3c318547fea24bc26
parentchangelog (diff)
downloadsynapse-53a327b4d5bdcab36651aa86ddf1815ff86e5db2.tar.xz
Require that service_name be a byte string
it is only ever a bytes now, so let's enforce that.
-rw-r--r--synapse/http/federation/srv_resolver.py8
-rw-r--r--tests/http/federation/test_srv_resolver.py8
2 files changed, 8 insertions, 8 deletions
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index a6e92fdf40..e05f934d0b 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -92,7 +92,7 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
     but the cache never gets populated), so we add our own caching layer here.
 
     Args:
-        service_name (unicode|bytes): record to look up
+        service_name (bytes): record to look up
         dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
         cache (dict): cache object
         clock (object): clock implementation. must provide a time() method.
@@ -100,9 +100,9 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
     Returns:
         Deferred[list[Server]]: a list of the SRV records, or an empty list if none found
     """
-    # TODO: the dns client handles both unicode names (encoding via idna) and pre-encoded
-    # byteses; however they will obviously end up as separate entries in the cache. We
-    # should pick one form and stick with it.
+    if not isinstance(service_name, bytes):
+        raise TypeError("%r is not a byte string" % (service_name,))
+
     cache_entry = cache.get(service_name, None)
     if cache_entry:
         if all(s.expires > int(clock.time()) for s in cache_entry):
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index 1271a495e1..de4d0089c8 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -83,7 +83,7 @@ class SrvResolverTestCase(unittest.TestCase):
         dns_client_mock = Mock()
         dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
 
-        service_name = "test_service.example.com"
+        service_name = b"test_service.example.com"
 
         entry = Mock(spec_set=["expires"])
         entry.expires = 0
@@ -106,7 +106,7 @@ class SrvResolverTestCase(unittest.TestCase):
         dns_client_mock = Mock(spec_set=['lookupService'])
         dns_client_mock.lookupService = Mock(spec_set=[])
 
-        service_name = "test_service.example.com"
+        service_name = b"test_service.example.com"
 
         entry = Mock(spec_set=["expires"])
         entry.expires = 999999999
@@ -128,7 +128,7 @@ class SrvResolverTestCase(unittest.TestCase):
 
         dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
 
-        service_name = "test_service.example.com"
+        service_name = b"test_service.example.com"
 
         cache = {}
 
@@ -141,7 +141,7 @@ class SrvResolverTestCase(unittest.TestCase):
 
         dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError())
 
-        service_name = "test_service.example.com"
+        service_name = b"test_service.example.com"
 
         cache = {}