summary refs log tree commit diff
path: root/synapse/http/federation/srv_resolver.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/federation/srv_resolver.py')
-rw-r--r--synapse/http/federation/srv_resolver.py35
1 files changed, 32 insertions, 3 deletions
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index b32188766d..bbda0a23f4 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
 SERVER_CACHE = {}
 
 
-@attr.s
+@attr.s(slots=True, frozen=True)
 class Server(object):
     """
     Our record of an individual server which can be tried to reach a destination.
@@ -83,6 +83,35 @@ def pick_server_from_list(server_list):
     raise RuntimeError("pick_server_from_list got to end of eligible server list.")
 
 
+def _sort_server_list(server_list):
+    """Given a list of SRV records sort them into priority order and shuffle
+    each priority with the given weight.
+    """
+    priority_map = {}
+
+    for server in server_list:
+        priority_map.setdefault(server.priority, []).append(server)
+
+    results = []
+    for priority in sorted(priority_map):
+        servers = priority_map.pop(priority)
+
+        while servers:
+            total_weight = sum(s.weight for s in servers)
+            target_weight = random.randint(0, total_weight)
+
+            for s in servers:
+                target_weight -= s.weight
+
+                if target_weight <= 0:
+                    break
+
+            results.append(s)
+            servers.remove(s)
+
+    return results
+
+
 class SrvResolver(object):
     """Interface to the dns client to do SRV lookups, with result caching.
 
@@ -120,7 +149,7 @@ class SrvResolver(object):
         if cache_entry:
             if all(s.expires > now for s in cache_entry):
                 servers = list(cache_entry)
-                return servers
+                return _sort_server_list(servers)
 
         try:
             answers, _, _ = yield make_deferred_yieldable(
@@ -169,4 +198,4 @@ class SrvResolver(object):
             )
 
         self._cache[service_name] = list(servers)
-        return servers
+        return _sort_server_list(servers)