summary refs log tree commit diff
diff options
context:
space:
mode:
authorDirk Klimpel <5740567+dklimpel@users.noreply.github.com>2021-08-18 19:53:20 +0200
committerGitHub <noreply@github.com>2021-08-18 13:53:20 -0400
commit0c3565da4cdbe53646ae0bc737900526a1d3df67 (patch)
tree2daef7c6b7db3a336f7b7dd187981ac46f68dbec
parentMerge branch 'release-v1.41' into develop (diff)
downloadsynapse-0c3565da4cdbe53646ae0bc737900526a1d3df67.tar.xz
Additional type hints for the proxy agent and SRV resolver modules. (#10608)
-rw-r--r--changelog.d/10608.misc1
-rw-r--r--mypy.ini3
-rw-r--r--synapse/http/additional_resource.py13
-rw-r--r--synapse/http/federation/srv_resolver.py45
-rw-r--r--synapse/http/proxyagent.py4
5 files changed, 41 insertions, 25 deletions
diff --git a/changelog.d/10608.misc b/changelog.d/10608.misc
new file mode 100644
index 0000000000..875bdd2fd0
--- /dev/null
+++ b/changelog.d/10608.misc
@@ -0,0 +1 @@
+Improve type hints for the proxy agent and SRV resolver modules. Contributed by @dklimpel.
\ No newline at end of file
diff --git a/mypy.ini b/mypy.ini
index e1b9405daa..107f4de76c 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -28,10 +28,13 @@ files =
   synapse/federation,
   synapse/groups,
   synapse/handlers,
+  synapse/http/additional_resource.py,
   synapse/http/client.py,
   synapse/http/federation/matrix_federation_agent.py,
+  synapse/http/federation/srv_resolver.py,
   synapse/http/federation/well_known_resolver.py,
   synapse/http/matrixfederationclient.py,
+  synapse/http/proxyagent.py,
   synapse/http/servlet.py,
   synapse/http/server.py,
   synapse/http/site.py,
diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py
index 55ea97a07f..9a2684aca4 100644
--- a/synapse/http/additional_resource.py
+++ b/synapse/http/additional_resource.py
@@ -12,8 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import TYPE_CHECKING
+
+from twisted.web.server import Request
+
 from synapse.http.server import DirectServeJsonResource
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 
 class AdditionalResource(DirectServeJsonResource):
     """Resource wrapper for additional_resources
@@ -25,7 +32,7 @@ class AdditionalResource(DirectServeJsonResource):
     and exception handling.
     """
 
-    def __init__(self, hs, handler):
+    def __init__(self, hs: "HomeServer", handler):
         """Initialise AdditionalResource
 
         The ``handler`` should return a deferred which completes when it has
@@ -33,14 +40,14 @@ class AdditionalResource(DirectServeJsonResource):
         ``request.write()``, and call ``request.finish()``.
 
         Args:
-            hs (synapse.server.HomeServer): homeserver
+            hs: homeserver
             handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
                 function to be called to handle the request.
         """
         super().__init__()
         self._handler = handler
 
-    def _async_render(self, request):
+    def _async_render(self, request: Request):
         # Cheekily pass the result straight through, so we don't need to worry
         # if its an awaitable or not.
         return self._handler(request)
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index b8ed4ec905..f68646fd0d 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -16,7 +16,7 @@
 import logging
 import random
 import time
-from typing import List
+from typing import Callable, Dict, List
 
 import attr
 
@@ -28,35 +28,35 @@ from synapse.logging.context import make_deferred_yieldable
 
 logger = logging.getLogger(__name__)
 
-SERVER_CACHE = {}
+SERVER_CACHE: Dict[bytes, List["Server"]] = {}
 
 
-@attr.s(slots=True, frozen=True)
+@attr.s(auto_attribs=True, slots=True, frozen=True)
 class Server:
     """
     Our record of an individual server which can be tried to reach a destination.
 
     Attributes:
-        host (bytes): target hostname
-        port (int):
-        priority (int):
-        weight (int):
-        expires (int): when the cache should expire this record - in *seconds* since
+        host: target hostname
+        port:
+        priority:
+        weight:
+        expires: when the cache should expire this record - in *seconds* since
             the epoch
     """
 
-    host = attr.ib()
-    port = attr.ib()
-    priority = attr.ib(default=0)
-    weight = attr.ib(default=0)
-    expires = attr.ib(default=0)
+    host: bytes
+    port: int
+    priority: int = 0
+    weight: int = 0
+    expires: int = 0
 
 
-def _sort_server_list(server_list):
+def _sort_server_list(server_list: List[Server]) -> List[Server]:
     """Given a list of SRV records sort them into priority order and shuffle
     each priority with the given weight.
     """
-    priority_map = {}
+    priority_map: Dict[int, List[Server]] = {}
 
     for server in server_list:
         priority_map.setdefault(server.priority, []).append(server)
@@ -103,11 +103,16 @@ class SrvResolver:
 
     Args:
         dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
-        cache (dict): cache object
-        get_time (callable): clock implementation. Should return seconds since the epoch
+        cache: cache object
+        get_time: clock implementation. Should return seconds since the epoch
     """
 
-    def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
+    def __init__(
+        self,
+        dns_client=client,
+        cache: Dict[bytes, List[Server]] = SERVER_CACHE,
+        get_time: Callable[[], float] = time.time,
+    ):
         self._dns_client = dns_client
         self._cache = cache
         self._get_time = get_time
@@ -116,7 +121,7 @@ class SrvResolver:
         """Look up a SRV record
 
         Args:
-            service_name (bytes): record to look up
+            service_name: record to look up
 
         Returns:
             a list of the SRV records, or an empty list if none found
@@ -158,7 +163,7 @@ class SrvResolver:
             and answers[0].payload
             and answers[0].payload.target == dns.Name(b".")
         ):
-            raise ConnectError("Service %s unavailable" % service_name)
+            raise ConnectError(f"Service {service_name!r} unavailable")
 
         servers = []
 
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index a3f31452d0..6fd88bde20 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -173,7 +173,7 @@ class ProxyAgent(_AgentBase):
             raise ValueError(f"Invalid URI {uri!r}")
 
         parsed_uri = URI.fromBytes(uri)
-        pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
+        pool_key = f"{parsed_uri.scheme!r}{parsed_uri.host!r}{parsed_uri.port}"
         request_path = parsed_uri.originForm
 
         should_skip_proxy = False
@@ -199,7 +199,7 @@ class ProxyAgent(_AgentBase):
                 )
             # Cache *all* connections under the same key, since we are only
             # connecting to a single destination, the proxy:
-            pool_key = ("http-proxy", self.http_proxy_endpoint)
+            pool_key = "http-proxy"
             endpoint = self.http_proxy_endpoint
             request_path = uri
         elif (