diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index b8ed4ec905..f68646fd0d 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -16,7 +16,7 @@
import logging
import random
import time
-from typing import List
+from typing import Callable, Dict, List
import attr
@@ -28,35 +28,35 @@ from synapse.logging.context import make_deferred_yieldable
logger = logging.getLogger(__name__)
-SERVER_CACHE = {}
+SERVER_CACHE: Dict[bytes, List["Server"]] = {}
-@attr.s(slots=True, frozen=True)
+@attr.s(auto_attribs=True, slots=True, frozen=True)
class Server:
"""
Our record of an individual server which can be tried to reach a destination.
Attributes:
- host (bytes): target hostname
- port (int):
- priority (int):
- weight (int):
- expires (int): when the cache should expire this record - in *seconds* since
+ host: target hostname
+ port:
+ priority:
+ weight:
+ expires: when the cache should expire this record - in *seconds* since
the epoch
"""
- host = attr.ib()
- port = attr.ib()
- priority = attr.ib(default=0)
- weight = attr.ib(default=0)
- expires = attr.ib(default=0)
+ host: bytes
+ port: int
+ priority: int = 0
+ weight: int = 0
+ expires: int = 0
-def _sort_server_list(server_list):
+def _sort_server_list(server_list: List[Server]) -> List[Server]:
"""Given a list of SRV records sort them into priority order and shuffle
each priority with the given weight.
"""
- priority_map = {}
+ priority_map: Dict[int, List[Server]] = {}
for server in server_list:
priority_map.setdefault(server.priority, []).append(server)
@@ -103,11 +103,16 @@ class SrvResolver:
Args:
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
- cache (dict): cache object
- get_time (callable): clock implementation. Should return seconds since the epoch
+ cache: cache object
+ get_time: clock implementation. Should return seconds since the epoch
"""
- def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
+ def __init__(
+ self,
+ dns_client=client,
+ cache: Dict[bytes, List[Server]] = SERVER_CACHE,
+ get_time: Callable[[], float] = time.time,
+ ):
self._dns_client = dns_client
self._cache = cache
self._get_time = get_time
@@ -116,7 +121,7 @@ class SrvResolver:
"""Look up a SRV record
Args:
- service_name (bytes): record to look up
+ service_name: record to look up
Returns:
a list of the SRV records, or an empty list if none found
@@ -158,7 +163,7 @@ class SrvResolver:
and answers[0].payload
and answers[0].payload.target == dns.Name(b".")
):
- raise ConnectError("Service %s unavailable" % service_name)
+ raise ConnectError(f"Service {service_name!r} unavailable")
servers = []
|