summary refs log tree commit diff
path: root/synapse/http/endpoint.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/endpoint.py')
-rw-r--r--synapse/http/endpoint.py50
1 files changed, 26 insertions, 24 deletions
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index 564ae4c10d..87a482650d 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
 from twisted.internet import defer, reactor
 from twisted.internet.error import ConnectError
@@ -30,7 +29,10 @@ logger = logging.getLogger(__name__)
 
 SERVER_CACHE = {}
 
-
+# our record of an individual server which can be tried to reach a destination.
+#
+# "host" is the hostname acquired from the SRV record. Except when there's
+# no SRV record, in which case it is the original hostname.
 _Server = collections.namedtuple(
     "_Server", "priority weight host port expires"
 )
@@ -224,9 +226,10 @@ class SRVClientEndpoint(object):
                 return self.default_server
             else:
                 raise ConnectError(
-                    "Not server available for %s" % self.service_name
+                    "No server available for %s" % self.service_name
                 )
 
+        # look for all servers with the same priority
         min_priority = self.servers[0].priority
         weight_indexes = list(
             (index, server.weight + 1)
@@ -236,11 +239,22 @@ class SRVClientEndpoint(object):
 
         total_weight = sum(weight for index, weight in weight_indexes)
         target_weight = random.randint(0, total_weight)
-
         for index, weight in weight_indexes:
             target_weight -= weight
             if target_weight <= 0:
                 server = self.servers[index]
+                # XXX: this looks totally dubious:
+                #
+                # (a) we never reuse a server until we have been through
+                #     all of the servers at the same priority, so if the
+                #     weights are A: 100, B:1, we always do ABABAB instead of
+                #     AAAA...AAAB (approximately).
+                #
+                # (b) After using all the servers at the lowest priority,
+                #     we move onto the next priority. We should only use the
+                #     second priority if servers at the top priority are
+                #     unreachable.
+                #
                 del self.servers[index]
                 self.used_servers.append(server)
                 return server
@@ -277,7 +291,7 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
         if (len(answers) == 1
                 and answers[0].type == dns.SRV
                 and answers[0].payload
-                and answers[0].payload.target == dns.Name('.')):
+                and answers[0].payload.target == dns.Name(b'.')):
             raise ConnectError("Service %s unavailable" % service_name)
 
         for answer in answers:
@@ -285,26 +299,14 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
                 continue
 
             payload = answer.payload
-            host = str(payload.target)
-            srv_ttl = answer.ttl
-
-            try:
-                answers, _, _ = yield dns_client.lookupAddress(host)
-            except DNSNameError:
-                continue
 
-            for answer in answers:
-                if answer.type == dns.A and answer.payload:
-                    ip = answer.payload.dottedQuad()
-                    host_ttl = min(srv_ttl, answer.ttl)
-
-                    servers.append(_Server(
-                        host=ip,
-                        port=int(payload.port),
-                        priority=int(payload.priority),
-                        weight=int(payload.weight),
-                        expires=int(clock.time()) + host_ttl,
-                    ))
+            servers.append(_Server(
+                host=str(payload.target),
+                port=int(payload.port),
+                priority=int(payload.priority),
+                weight=int(payload.weight),
+                expires=int(clock.time()) + answer.ttl,
+            ))
 
         servers.sort()
         cache[service_name] = list(servers)