diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index 1c3b7ea28a..815f8ff2f7 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -12,29 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import collections
import logging
import random
import re
-import time
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet.error import ConnectError
-from twisted.names import client, dns
-from twisted.names.error import DNSNameError, DomainError
-logger = logging.getLogger(__name__)
-
-SERVER_CACHE = {}
+from synapse.http.federation.srv_resolver import Server, resolve_service
-# our record of an individual server which can be tried to reach a destination.
-#
-# "host" is the hostname acquired from the SRV record. Except when there's
-# no SRV record, in which case it is the original hostname.
-_Server = collections.namedtuple(
- "_Server", "priority weight host port expires"
-)
+logger = logging.getLogger(__name__)
def parse_server_name(server_name):
@@ -165,12 +153,9 @@ class SRVClientEndpoint(object):
self.service_name = "_%s._%s.%s" % (service, protocol, domain)
if default_port is not None:
- self.default_server = _Server(
+ self.default_server = Server(
host=domain,
port=default_port,
- priority=0,
- weight=0,
- expires=0,
)
else:
self.default_server = None
@@ -240,57 +225,3 @@ class SRVClientEndpoint(object):
)
connection = yield endpoint.connect(protocolFactory)
defer.returnValue(connection)
-
-
-@defer.inlineCallbacks
-def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
- cache_entry = cache.get(service_name, None)
- if cache_entry:
- if all(s.expires > int(clock.time()) for s in cache_entry):
- servers = list(cache_entry)
- defer.returnValue(servers)
-
- servers = []
-
- try:
- try:
- answers, _, _ = yield dns_client.lookupService(service_name)
- except DNSNameError:
- defer.returnValue([])
-
- if (len(answers) == 1
- and answers[0].type == dns.SRV
- and answers[0].payload
- and answers[0].payload.target == dns.Name(b'.')):
- raise ConnectError("Service %s unavailable" % service_name)
-
- for answer in answers:
- if answer.type != dns.SRV or not answer.payload:
- continue
-
- payload = answer.payload
-
- servers.append(_Server(
- host=str(payload.target),
- port=int(payload.port),
- priority=int(payload.priority),
- weight=int(payload.weight),
- expires=int(clock.time()) + answer.ttl,
- ))
-
- servers.sort()
- cache[service_name] = list(servers)
- except DomainError as e:
- # We failed to resolve the name (other than a NameError)
- # Try something in the cache, else rereaise
- cache_entry = cache.get(service_name, None)
- if cache_entry:
- logger.warn(
- "Failed to resolve %r, falling back to cache. %r",
- service_name, e
- )
- servers = list(cache_entry)
- else:
- raise e
-
- defer.returnValue(servers)
|