diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index b32188766d..3fe4ffb9e5 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
SERVER_CACHE = {}
-@attr.s
+@attr.s(slots=True, frozen=True)
class Server(object):
"""
Our record of an individual server which can be tried to reach a destination.
@@ -53,34 +53,47 @@ class Server(object):
expires = attr.ib(default=0)
-def pick_server_from_list(server_list):
- """Randomly choose a server from the server list
+def _sort_server_list(server_list):
+ """Given a list of SRV records sort them into priority order and shuffle
+ each priority with the given weight.
+ """
+ priority_map = {}
- Args:
- server_list (list[Server]): list of candidate servers
+ for server in server_list:
+ priority_map.setdefault(server.priority, []).append(server)
- Returns:
- Tuple[bytes, int]: (host, port) pair for the chosen server
- """
- if not server_list:
- raise RuntimeError("pick_server_from_list called with empty list")
+ results = []
+ for priority in sorted(priority_map):
+ servers = priority_map[priority]
+
+ # This algorithms roughly follows the algorithm described in RFC2782,
+ # changed to remove an off-by-one error.
+ #
+ # N.B. Weights can be zero, which means that they should be picked
+ # rarely.
+
+ total_weight = sum(s.weight for s in servers)
+
+ # Total weight can become zero if there are only zero weight servers
+ # left, which we handle by just shuffling and appending to the results.
+ while servers and total_weight:
+ target_weight = random.randint(1, total_weight)
- # TODO: currently we only use the lowest-priority servers. We should maintain a
- # cache of servers known to be "down" and filter them out
+ for s in servers:
+ target_weight -= s.weight
- min_priority = min(s.priority for s in server_list)
- eligible_servers = list(s for s in server_list if s.priority == min_priority)
- total_weight = sum(s.weight for s in eligible_servers)
- target_weight = random.randint(0, total_weight)
+ if target_weight <= 0:
+ break
- for s in eligible_servers:
- target_weight -= s.weight
+ results.append(s)
+ servers.remove(s)
+ total_weight -= s.weight
- if target_weight <= 0:
- return s.host, s.port
+ if servers:
+ random.shuffle(servers)
+ results.extend(servers)
- # this should be impossible.
- raise RuntimeError("pick_server_from_list got to end of eligible server list.")
+ return results
class SrvResolver(object):
@@ -120,7 +133,7 @@ class SrvResolver(object):
if cache_entry:
if all(s.expires > now for s in cache_entry):
servers = list(cache_entry)
- return servers
+ return _sort_server_list(servers)
try:
answers, _, _ = yield make_deferred_yieldable(
@@ -169,4 +182,4 @@ class SrvResolver(object):
)
self._cache[service_name] = list(servers)
- return servers
+ return _sort_server_list(servers)
|